Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Regression Rescorer #52059

Closed
wants to merge 16 commits into from
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;

/**
 * A no-op {@link InferenceConfigUpdate}: applying it to any {@link InferenceConfig}
 * returns that config unchanged. Used when a caller has no settings to override.
 *
 * <p>The class is stateless, so all instances are interchangeable: {@code equals}
 * treats any two instances as equal and {@code hashCode} is a fixed value.
 */
public class EmptyConfigUpdate implements InferenceConfigUpdate {
    public static final ParseField NAME = new ParseField("empty");

    // Parses the (expected empty) object body: {}
    private static final ObjectParser<EmptyConfigUpdate, Void> PARSER =
        new ObjectParser<>(NAME.getPreferredName(), EmptyConfigUpdate::new);

    public static EmptyConfigUpdate fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    public EmptyConfigUpdate() {
    }

    // Stateless: nothing to read from the wire.
    public EmptyConfigUpdate(StreamInput in) {
    }

    /** Returns the original config untouched — this update carries no overrides. */
    @Override
    public InferenceConfig apply(InferenceConfig originalConfig) {
        return originalConfig;
    }

    @Override
    public InferenceConfig toConfig() {
        // NOTE(review): falls back to an empty regression config; presumably callers
        // only hit this path for regression models — confirm against call sites.
        return RegressionConfig.EMPTY_PARAMS;
    }

    /** An empty update is compatible with every config type. */
    @Override
    public boolean isSupported(InferenceConfig config) {
        return true;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Stateless: nothing to write to the wire.
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        // Equal iff o is a non-null instance of exactly this class (all instances
        // are interchangeable because the class holds no state).
        return o != null && getClass() == o.getClass();
    }

    @Override
    public int hashCode() {
        // Fixed, class-based hash so hashCode is consistent with equals: every
        // instance is equal to every other, so all must share one hash value.
        // The previous super.hashCode() was identity-based and violated the
        // Object equals/hashCode contract (equal objects hashed differently).
        return EmptyConfigUpdate.class.hashCode();
    }
}
3 changes: 2 additions & 1 deletion x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies {
// bring in machine learning rest test suite
restResources {
restApi {
includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest'
includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest', 'bulk'
includeXpack 'ml', 'cat'
}
restTests {
Expand Down Expand Up @@ -206,6 +206,7 @@ integTest.runner {
'ml/validate_detector/Test invalid detector',
'ml/delete_forecast/Test delete on _all forecasts not allow no forecasts',
'ml/delete_forecast/Test delete forecast on missing forecast',
'ml/rescore/Test rescore',
'ml/set_upgrade_mode/Attempt to open job when upgrade_mode is enabled',
'ml/set_upgrade_mode/Setting upgrade_mode to enabled',
'ml/set_upgrade_mode/Setting upgrade mode to disabled from enabled',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.rest.RestController;
Expand Down Expand Up @@ -229,6 +230,7 @@
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.inference.search.InferenceRescorerBuilder;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
Expand Down Expand Up @@ -338,7 +340,7 @@
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;

public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
Expand All @@ -362,6 +364,14 @@ protected Setting<Boolean> roleSetting() {

};

// Registers the ML inference rescorer with the search module so that it can be
// referenced from a search request's rescore section, wiring both the wire
// (StreamInput) and REST (XContentParser) construction paths.
// NOTE(review): the lambdas capture the SetOnce modelLoadingService, which is
// populated later in createComponents — assumes no rescorer is deserialized or
// parsed before component creation; confirm.
@Override
public List<RescorerSpec<?>> getRescorers() {
return Collections.singletonList(
new RescorerSpec<>(InferenceRescorerBuilder.NAME,
in -> new InferenceRescorerBuilder(in, modelLoadingService),
parser -> InferenceRescorerBuilder.fromXContent(parser, modelLoadingService)));
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
if (this.enabled == false) {
Expand Down Expand Up @@ -450,6 +460,7 @@ public Set<DiscoveryNodeRole> getRoles() {
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();

public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
Expand Down Expand Up @@ -678,6 +689,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
trainedModelStatsService,
settings,
clusterService.getNodeName());
this.modelLoadingService.set(modelLoadingService);

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
Expand Down Expand Up @@ -52,7 +52,7 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li

Response.Builder responseBuilder = Response.builder();

ActionListener<Model> getModelListener = ActionListener.wrap(
ActionListener<LocalModel> getModelListener = ActionListener.wrap(
model -> {
TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor =
new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;

Expand All @@ -32,7 +32,6 @@ public class LocalModel implements Model {

private final TrainedModelDefinition trainedModelDefinition;
private final String modelId;
private final String nodeId;
private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap;
private final InferenceStats.Accumulator statsAccumulator;
Expand All @@ -50,7 +49,6 @@ public LocalModel(String modelId,
TrainedModelStatsService trainedModelStatsService ) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.nodeId = nodeId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
this.trainedModelStatsService = trainedModelStatsService;
Expand All @@ -68,6 +66,10 @@ public String getModelId() {
return modelId;
}

public Set<String> getFieldNames() {
return fieldNames;
}

@Override
public InferenceStats getLatestStatsAndReset() {
return statsAccumulator.currentStatsAndReset();
Expand Down Expand Up @@ -99,14 +101,23 @@ void persistStats(boolean flush) {

@Override
public void infer(Map<String, Object> fields, InferenceConfigUpdate update, ActionListener<InferenceResults> listener) {
try {
InferenceResults result = infer(fields, update);
listener.onResponse(result);
} catch (Exception e) {
listener.onFailure(e);
}
}

public InferenceResults infer(Map<String, Object> fields, InferenceConfigUpdate update) {
if (update.isSupported(this.inferenceConfig) == false) {
listener.onFailure(ExceptionsHelper.badRequestException(
throw ExceptionsHelper.badRequestException(
"Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]",
this.modelId,
this.inferenceConfig.getName(),
update.getName()));
return;
update.getName());
}

try {
statsAccumulator.incInference();
currentInferenceCount.increment();
Expand All @@ -119,18 +130,17 @@ public void infer(Map<String, Object> fields, InferenceConfigUpdate update, Acti
if (shouldPersistStats) {
persistStats(false);
}
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;
return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId));
}
InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig));
if (shouldPersistStats) {
persistStats(false);
}
listener.onResponse(inferenceResults);

return inferenceResults;
} catch (Exception e) {
statsAccumulator.incFailure();
listener.onFailure(e);
throw e;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public class ModelLoadingService implements ClusterStateListener {
private final TrainedModelStatsService modelStatsService;
private final Cache<String, LocalModel> localModelCache;
private final Set<String> referencedModels = new HashSet<>();
private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
private final Map<String, Queue<ActionListener<LocalModel>>> loadingListeners = new HashMap<>();
private final TrainedModelProvider provider;
private final Set<String> shouldNotAudit;
private final ThreadPool threadPool;
Expand Down Expand Up @@ -140,7 +140,7 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider,
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved.
*/
public void getModel(String modelId, ActionListener<Model> modelActionListener) {
public void getModel(String modelId, ActionListener<LocalModel> modelActionListener) {
LocalModel cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
Expand Down Expand Up @@ -178,9 +178,9 @@ public void getModel(String modelId, ActionListener<Model> modelActionListener)
* Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
* Returns false if the model is not loaded or actively being loaded
*/
private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
private boolean loadModelIfNecessary(String modelId, ActionListener<LocalModel> modelActionListener) {
synchronized (loadingListeners) {
Model cachedModel = localModelCache.get(modelId);
LocalModel cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
return true;
Expand Down Expand Up @@ -219,7 +219,7 @@ private void loadModel(String modelId) {
}

private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException {
Queue<ActionListener<Model>> listeners;
Queue<ActionListener<LocalModel>> listeners;
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
Expand All @@ -242,13 +242,13 @@ private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelCo
localModelCache.put(modelId, loadedModel);
shouldNotAudit.remove(modelId);
} // synchronized (loadingListeners)
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
listener.onResponse(loadedModel);
}
}

private void handleLoadFailure(String modelId, Exception failure) {
Queue<ActionListener<Model>> listeners;
Queue<ActionListener<LocalModel>> listeners;
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
if (listeners == null) {
Expand All @@ -257,7 +257,7 @@ private void handleLoadFailure(String modelId, Exception failure) {
} // synchronized (loadingListeners)
// If we failed to load and there were listeners present, that means that this model is referenced by a processor
// Alert the listeners to the failure
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
listener.onFailure(failure);
}
}
Expand Down Expand Up @@ -294,7 +294,7 @@ public void clusterChanged(ClusterChangedEvent event) {
return;
}
// The listeners still waiting for a model and we are canceling the load?
List<Tuple<String, List<ActionListener<Model>>>> drainWithFailure = new ArrayList<>();
List<Tuple<String, List<ActionListener<LocalModel>>>> drainWithFailure = new ArrayList<>();
Set<String> referencedModelsBeforeClusterState = null;
Set<String> loadingModelBeforeClusterState = null;
Set<String> removedModels = null;
Expand Down Expand Up @@ -337,11 +337,11 @@ public void clusterChanged(ClusterChangedEvent event) {
referencedModels);
}
}
for (Tuple<String, List<ActionListener<Model>>> modelAndListeners : drainWithFailure) {
for (Tuple<String, List<ActionListener<LocalModel>>> modelAndListeners : drainWithFailure) {
final String msg = new ParameterizedMessage(
"Cancelling load of model [{}] as it is no longer referenced by a pipeline",
modelAndListeners.v1()).getFormat();
for (ActionListener<Model> listener : modelAndListeners.v2()) {
for (ActionListener<LocalModel> listener : modelAndListeners.v2()) {
listener.onFailure(new ElasticsearchException(msg));
}
}
Expand Down Expand Up @@ -430,7 +430,7 @@ private static InferenceConfig inferenceConfigFromTargetType(TargetType targetTy
* @param modelId Model Id
* @param modelLoadedListener To be notified
*/
void addModelLoadedListener(String modelId, ActionListener<Model> modelLoadedListener) {
void addModelLoadedListener(String modelId, ActionListener<LocalModel> modelLoadedListener) {
synchronized (loadingListeners) {
loadingListeners.compute(modelId, (modelKey, listenerQueue) -> {
if (listenerQueue == null) {
Expand Down
Loading