diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java
new file mode 100644
index 0000000000000..ff50aaa6ab558
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+
+/**
+ * An {@link InferenceConfigUpdate} that changes nothing: {@link #apply} returns the
+ * configuration it is given and there is no state to serialise.
+ */
+public class EmptyConfigUpdate implements InferenceConfigUpdate {
+    public static final ParseField NAME = new ParseField("empty");
+
+    private static final ObjectParser<EmptyConfigUpdate, Void> PARSER =
+        new ObjectParser<>(NAME.getPreferredName(), EmptyConfigUpdate::new);
+
+    public static EmptyConfigUpdate fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public EmptyConfigUpdate() {
+    }
+
+    /** Stream constructor; the update is stateless, so nothing is read from {@code in}. */
+    public EmptyConfigUpdate(StreamInput in) {
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        return originalConfig;
+    }
+
+    @Override
+    public InferenceConfig toConfig() {
+        return RegressionConfig.EMPTY_PARAMS;
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return true;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        // Nothing to write: the update is stateless.
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        // Equal if o is not null and the same class
+        return o != null && getClass() == o.getClass();
+    }
+
+    @Override
+    public int hashCode() {
+        // equals() treats every instance of the same class as equal, so all of them must
+        // share one hash code; Object's identity hash would break that contract.
+        return getClass().hashCode();
+    }
+}
diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle
index 84ea1372cb1aa..9cc8e0c90e7ec 100644
--- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle
+++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle
@@ -11,7 +11,7 @@ dependencies {
 // bring in machine learning rest test suite
 restResources {
   restApi {
-    includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest'
+    includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest', 'bulk'
     includeXpack 'ml', 'cat'
   }
   restTests {
@@ -206,6 +206,7 @@ integTest.runner {
     'ml/validate_detector/Test invalid detector',
     'ml/delete_forecast/Test delete on _all forecasts not allow no forecasts',
     'ml/delete_forecast/Test delete forecast on missing forecast',
+    'ml/rescore/Test rescore',
     'ml/set_upgrade_mode/Attempt to open job when upgrade_mode is enabled',
     'ml/set_upgrade_mode/Setting upgrade_mode to enabled',
     'ml/set_upgrade_mode/Setting upgrade mode to disabled from enabled',
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 56c9623dc439b..54e7e2da31d38 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -53,6 +53,7 @@ import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.rest.RestController; @@ -229,6 +230,7 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.inference.search.InferenceRescorerBuilder; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -338,7 +340,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -362,6 +364,14 @@ protected Setting roleSetting() { }; + @Override + public List> getRescorers() { + return Collections.singletonList( + new RescorerSpec<>(InferenceRescorerBuilder.NAME, + in -> new InferenceRescorerBuilder(in, modelLoadingService), + parser -> InferenceRescorerBuilder.fromXContent(parser, modelLoadingService))); + } + @Override public Map getProcessors(Processor.Parameters parameters) { if (this.enabled == false) { @@ -450,6 +460,7 @@ public Set getRoles() { private final SetOnce 
dataFrameAnalyticsAuditor = new SetOnce<>(); private final SetOnce memoryTracker = new SetOnce<>(); private final SetOnce mlUpgradeModeActionFilter = new SetOnce<>(); + private final SetOnce modelLoadingService = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -678,6 +689,7 @@ public Collection createComponents(Client client, ClusterService cluster trainedModelStatsService, settings, clusterService.getNodeName()); + this.modelLoadingService.set(modelLoadingService); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index d229f4decbee7..698b353aefbc2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; @@ -52,7 +52,7 @@ protected void doExecute(Task task, Request request, ActionListener li Response.Builder responseBuilder = Response.builder(); - ActionListener getModelListener = ActionListener.wrap( + 
ActionListener getModelListener = ActionListener.wrap( model -> { TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index e7da7a36184f8..9758b1f8a078b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -8,15 +8,15 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; @@ -32,7 +32,6 @@ public class LocalModel implements 
Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; - private final String nodeId; private final Set fieldNames; private final Map defaultFieldMap; private final InferenceStats.Accumulator statsAccumulator; @@ -50,7 +49,6 @@ public LocalModel(String modelId, TrainedModelStatsService trainedModelStatsService ) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; - this.nodeId = nodeId; this.fieldNames = new HashSet<>(input.getFieldNames()); this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId); this.trainedModelStatsService = trainedModelStatsService; @@ -68,6 +66,10 @@ public String getModelId() { return modelId; } + public Set getFieldNames() { + return fieldNames; + } + @Override public InferenceStats getLatestStatsAndReset() { return statsAccumulator.currentStatsAndReset(); @@ -99,14 +101,23 @@ void persistStats(boolean flush) { @Override public void infer(Map fields, InferenceConfigUpdate update, ActionListener listener) { + try { + InferenceResults result = infer(fields, update); + listener.onResponse(result); + } catch (Exception e) { + listener.onFailure(e); + } + } + + public InferenceResults infer(Map fields, InferenceConfigUpdate update) { if (update.isSupported(this.inferenceConfig) == false) { - listener.onFailure(ExceptionsHelper.badRequestException( + throw ExceptionsHelper.badRequestException( "Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]", this.modelId, this.inferenceConfig.getName(), - update.getName())); - return; + update.getName()); } + try { statsAccumulator.incInference(); currentInferenceCount.increment(); @@ -119,18 +130,17 @@ public void infer(Map fields, InferenceConfigUpdate update, Acti if (shouldPersistStats) { persistStats(false); } - listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); - return; + return new 
WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)); } InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig)); if (shouldPersistStats) { persistStats(false); } - listener.onResponse(inferenceResults); + + return inferenceResults; } catch (Exception e) { statsAccumulator.incFailure(); - listener.onFailure(e); + throw e; } } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 5053972e50bec..dffa3342585e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -88,7 +88,7 @@ public class ModelLoadingService implements ClusterStateListener { private final TrainedModelStatsService modelStatsService; private final Cache localModelCache; private final Set referencedModels = new HashSet<>(); - private final Map>> loadingListeners = new HashMap<>(); + private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final Set shouldNotAudit; private final ThreadPool threadPool; @@ -140,7 +140,7 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, * @param modelId the model to get * @param modelActionListener the listener to alert when the model has been retrieved. 
*/ - public void getModel(String modelId, ActionListener modelActionListener) { + public void getModel(String modelId, ActionListener modelActionListener) { LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); @@ -178,9 +178,9 @@ public void getModel(String modelId, ActionListener modelActionListener) * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded * Returns false if the model is not loaded or actively being loaded */ - private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { synchronized (loadingListeners) { - Model cachedModel = localModelCache.get(modelId); + LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); return true; @@ -219,7 +219,7 @@ private void loadModel(String modelId) { } private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException { - Queue> listeners; + Queue> listeners; trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? 
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : @@ -242,13 +242,13 @@ private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelCo localModelCache.put(modelId, loadedModel); shouldNotAudit.remove(modelId); } // synchronized (loadingListeners) - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onResponse(loadedModel); } } private void handleLoadFailure(String modelId, Exception failure) { - Queue> listeners; + Queue> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); if (listeners == null) { @@ -257,7 +257,7 @@ private void handleLoadFailure(String modelId, Exception failure) { } // synchronized (loadingListeners) // If we failed to load and there were listeners present, that means that this model is referenced by a processor // Alert the listeners to the failure - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onFailure(failure); } } @@ -294,7 +294,7 @@ public void clusterChanged(ClusterChangedEvent event) { return; } // The listeners still waiting for a model and we are canceling the load? 
- List>>> drainWithFailure = new ArrayList<>(); + List>>> drainWithFailure = new ArrayList<>(); Set referencedModelsBeforeClusterState = null; Set loadingModelBeforeClusterState = null; Set removedModels = null; @@ -337,11 +337,11 @@ public void clusterChanged(ClusterChangedEvent event) { referencedModels); } } - for (Tuple>> modelAndListeners : drainWithFailure) { + for (Tuple>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( "Cancelling load of model [{}] as it is no longer referenced by a pipeline", modelAndListeners.v1()).getFormat(); - for (ActionListener listener : modelAndListeners.v2()) { + for (ActionListener listener : modelAndListeners.v2()) { listener.onFailure(new ElasticsearchException(msg)); } } @@ -430,7 +430,7 @@ private static InferenceConfig inferenceConfigFromTargetType(TargetType targetTy * @param modelId Model Id * @param modelLoadedListener To be notified */ - void addModelLoadedListener(String modelId, ActionListener modelLoadedListener) { + void addModelLoadedListener(String modelId, ActionListener modelLoadedListener) { synchronized (loadingListeners) { loadingListeners.compute(modelId, (modelKey, listenerQueue) -> { if (listenerQueue == null) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java new file mode 100644 index 0000000000000..5e95bc394fd23 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -0,0 +1,135 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+
+package org.elasticsearch.xpack.ml.inference.search;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.index.DocValues;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.SortedNumericDocValues;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.index.fielddata.FieldData;
+import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
+import org.elasticsearch.search.rescore.RescoreContext;
+import org.elasticsearch.search.rescore.Rescorer;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rescorer that runs a locally loaded trained model over the numeric doc values of each hit
+ * and combines the model output with the hit's query score.
+ */
+public class InferenceRescorer implements Rescorer {
+
+    private static final Logger logger = LogManager.getLogger(InferenceRescorer.class);
+
+    private final LocalModel model;
+    private final InferenceConfigUpdate update;
+    private final Map<String, String> fieldMap;
+    private final InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings;
+
+    /**
+     * @param model             the model used for inference
+     * @param update            the config update passed to every inference call
+     * @param fieldMap          document field name -> field name expected by the model; may be null
+     * @param scoreModeSettings weights and mode used to combine query and model scores
+     */
+    InferenceRescorer(
+        LocalModel model,
+        InferenceConfigUpdate update,
+        Map<String, String> fieldMap,
+        InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings
+    ) {
+        this.model = model;
+        this.update = update;
+        // The field mappings are optional in the request, so the builder may hand over null
+        this.fieldMap = fieldMap == null ? Collections.emptyMap() : fieldMap;
+        this.scoreModeSettings = scoreModeSettings;
+    }
+
+    /**
+     * Scores every hit with the model and returns the hits ordered by the combined score.
+     *
+     * @throws ElasticsearchException if the model returns a warning instead of a value
+     */
+    @Override
+    public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException {
+
+        // Copy ScoreDoc[] and sort by ascending docID so that segments are visited in order
+        ScoreDoc[] sortedHits = topDocs.scoreDocs.clone();
+        Comparator<ScoreDoc> docIdComparator = Comparator.comparingInt(sd -> sd.doc);
+        Arrays.sort(sortedHits, docIdComparator);
+
+        Set<String> fieldsToRead = new HashSet<>(model.getFieldNames());
+        // field map is fieldname in doc -> fieldname expected by model
+        for (Map.Entry<String, String> entry : fieldMap.entrySet()) {
+            if (fieldsToRead.contains(entry.getValue())) {
+                // replace the model fieldname with the doc fieldname
+                fieldsToRead.remove(entry.getValue());
+                fieldsToRead.add(entry.getKey());
+            }
+        }
+
+        List<LeafReaderContext> leaves = searcher.getIndexReader().getContext().leaves();
+        Map<String, Object> fields = new HashMap<>();
+
+        int currentReader = 0;
+        int endDoc = 0;
+        LeafReaderContext readerContext = null;
+
+        for (int hitIndex = 0; hitIndex < sortedHits.length; hitIndex++) {
+            ScoreDoc hit = sortedHits[hitIndex];
+            int docId = hit.doc;
+
+            // get the context for this docId
+            while (docId >= endDoc) {
+                readerContext = leaves.get(currentReader);
+                currentReader++;
+                endDoc = readerContext.docBase + readerContext.reader().maxDoc();
+            }
+            // Doc values are addressed per segment: translate the top-level doc id
+            int segmentDocId = docId - readerContext.docBase;
+
+            for (String field : fieldsToRead) {
+                // The model sees the mapped name, so values are stored and cleared under that key
+                String modelField = fieldMap.getOrDefault(field, field);
+                SortedNumericDocValues docValuesIter = DocValues.getSortedNumeric(readerContext.reader(), field);
+                SortedNumericDoubleValues doubles = FieldData.sortableLongBitsToDoubles(docValuesIter);
+                if (doubles.advanceExact(segmentDocId)) {
+                    double val = doubles.nextValue();
+                    fields.put(modelField, val);
+                } else if (docValuesIter.docID() == DocIdSetIterator.NO_MORE_DOCS) {
+                    logger.warn("No more docs for field {}, doc {}", field, hit.doc);
+                    fields.remove(modelField);
+                } else {
+                    logger.warn("no value for field {}, doc {}", field, hit.doc);
+                    fields.remove(modelField);
+                }
+            }
+
+            InferenceResults infer = model.infer(fields, update);
+            if (infer instanceof WarningInferenceResults) {
+                String message = ((WarningInferenceResults) infer).getWarning();
+                logger.warn("inference error: " + message);
+                throw new ElasticsearchException(message);
+            } else {
+                SingleValueInferenceResults regressionResult = (SingleValueInferenceResults) infer;
+
+                float combinedScore = scoreModeSettings.scoreMode.combine(
+                    hit.score * scoreModeSettings.queryWeight,
+                    regressionResult.value().floatValue() * scoreModeSettings.modelWeight
+                );
+
+                // Keep the shard index of the original hit
+                sortedHits[hitIndex] = new ScoreDoc(hit.doc, combinedScore, hit.shardIndex);
+            }
+        }
+
+        // The array was ordered by doc id for reading; hand it back ordered by the new
+        // score (highest first), breaking ties on doc id
+        Arrays.sort(sortedHits, (a, b) -> {
+            int byScore = Float.compare(b.score, a.score);
+            return byScore != 0 ? byScore : Integer.compare(a.doc, b.doc);
+        });
+
+        return new TopDocs(topDocs.totalHits, sortedHits);
+    }
+
+    @Override
+    public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, Explanation sourceExplanation) {
+        // Placeholder: the explanation does not reflect the model output or the combined score
+        return Explanation.match(1.0, "because");
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java
new file mode 100644
index 0000000000000..5bd3b62c96d11
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java
@@ -0,0 +1,280 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.search;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.QueryShardContext;
+import org.elasticsearch.search.rescore.QueryRescoreMode;
+import org.elasticsearch.search.rescore.RescoreContext;
+import org.elasticsearch.search.rescore.RescorerBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Builder for the {@code ml_rescore} rescorer. It is parsed from the search request, loads
+ * the referenced model asynchronously during {@link #rewrite}, and finally creates an
+ * {@link InferenceRescorer} in {@link #innerBuildContext}.
+ */
+public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerBuilder> {
+
+    public static final String NAME = "ml_rescore";
+
+    public static final ParseField MODEL_ID = new ParseField("model_id");
+    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
+    private static final ParseField FIELD_MAPPINGS = new ParseField("field_mappings");
+
+    private static final ParseField QUERY_WEIGHT = new ParseField("query_weight");
+    private static final ParseField MODEL_WEIGHT = new ParseField("model_weight");
+    private static final ParseField SCORE_MODE = new ParseField("score_mode");
+
+    private static final float DEFAULT_QUERY_WEIGHT = 1.0f;
+    private static final float DEFAULT_MODEL_WEIGHT = 1.0f;
+    private static final QueryRescoreMode DEFAULT_SCORE_MODE = QueryRescoreMode.Total;
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<InferenceRescorerBuilder, SetOnce<ModelLoadingService>> PARSER =
+        new ConstructingObjectParser<>(NAME, false,
+            (args, context) ->
+                new InferenceRescorerBuilder((String) args[0], context, (InferenceConfig) args[1], (Map<String, String>) args[2])
+        );
+
+    static {
+        PARSER.declareString(constructorArg(), MODEL_ID);
+        PARSER.declareNamedObject(optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(StrictlyParsedInferenceConfig.class, n, c), INFERENCE_CONFIG);
+        PARSER.declareField(optionalConstructorArg(), (p, c) -> p.mapStrings(), FIELD_MAPPINGS, ObjectParser.ValueType.OBJECT);
+        PARSER.declareFloat(InferenceRescorerBuilder::setQueryWeight, QUERY_WEIGHT);
+        PARSER.declareFloat(InferenceRescorerBuilder::setModelWeight, MODEL_WEIGHT);
+        PARSER.declareString((builder, mode) -> builder.setScoreMode(QueryRescoreMode.fromString(mode)), SCORE_MODE);
+    }
+
+    public static InferenceRescorerBuilder fromXContent(XContentParser parser, SetOnce<ModelLoadingService> modelLoadingService) {
+        return PARSER.apply(parser, modelLoadingService);
+    }
+
+    private final String modelId;
+    private final SetOnce<ModelLoadingService> modelLoadingService;
+    private final InferenceConfig inferenceConfig;
+    private final Map<String, String> fieldMap;
+
+    // At most one of these is set, and only on builders created by rewrite()
+    private LocalModel model;
+    private Supplier<LocalModel> modelSupplier;
+
+    private float queryWeight = DEFAULT_QUERY_WEIGHT;
+    private float modelWeight = DEFAULT_MODEL_WEIGHT;
+    private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE;
+
+    public InferenceRescorerBuilder(String modelId,
+                                    SetOnce<ModelLoadingService> modelLoadingService,
+                                    InferenceConfig config,
+                                    @Nullable Map<String, String> fieldMap) {
+        this.modelId = modelId;
+        this.modelLoadingService = modelLoadingService;
+        this.inferenceConfig = config;
+        this.fieldMap = fieldMap;
+    }
+
+    /** Used by rewrite() while the model load is still in flight. */
+    private InferenceRescorerBuilder(String modelId,
+                                     SetOnce<ModelLoadingService> modelLoadingService,
+                                     @Nullable InferenceConfig config,
+                                     @Nullable Map<String, String> fieldMap,
+                                     Supplier<LocalModel> modelSupplier
+    ) {
+        this(modelId, modelLoadingService, config, fieldMap);
+        this.modelSupplier = modelSupplier;
+    }
+
+    /** Used by rewrite() once the model has been loaded. */
+    private InferenceRescorerBuilder(String modelId,
+                                     SetOnce<ModelLoadingService> modelLoadingService,
+                                     @Nullable InferenceConfig config,
+                                     @Nullable Map<String, String> fieldMap,
+                                     LocalModel model
+    ) {
+        this(modelId, modelLoadingService, config, fieldMap);
+        this.model = model;
+    }
+
+    /** Stream constructor; reads the fields in the order {@link #doWriteTo} writes them. */
+    public InferenceRescorerBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
+        super(in);
+        modelId = in.readString();
+        inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class);
+        boolean readMap = in.readBoolean();
+        if (readMap) {
+            fieldMap = in.readMap(StreamInput::readString, StreamInput::readString);
+        } else {
+            fieldMap = null;
+        }
+        queryWeight = in.readFloat();
+        modelWeight = in.readFloat();
+        scoreMode = QueryRescoreMode.readFromStream(in);
+
+        this.modelLoadingService = modelLoadingService;
+    }
+
+    void setQueryWeight(float queryWeight) {
+        this.queryWeight = queryWeight;
+    }
+
+    void setModelWeight(float modelWeight) {
+        this.modelWeight = modelWeight;
+    }
+
+    void setScoreMode(QueryRescoreMode scoreMode) {
+        this.scoreMode = scoreMode;
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeOptionalNamedWriteable(inferenceConfig);
+        boolean fieldMapPresent = fieldMap != null;
+        out.writeBoolean(fieldMapPresent);
+        if (fieldMapPresent) {
+            out.writeMap(fieldMap, StreamOutput::writeString, StreamOutput::writeString);
+        }
+        out.writeFloat(queryWeight);
+        out.writeFloat(modelWeight);
+        scoreMode.writeTo(out);
+    }
+
+    @Override
+    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID.getPreferredName(), modelId);
+        if (inferenceConfig != null) {
+            builder.startObject(INFERENCE_CONFIG.getPreferredName());
+            builder.field(inferenceConfig.getName(), inferenceConfig);
+            builder.endObject();
+        }
+        if (fieldMap != null) {
+            builder.field(FIELD_MAPPINGS.getPreferredName(), fieldMap);
+        }
+        builder.field(QUERY_WEIGHT.getPreferredName(), queryWeight);
+        builder.field(MODEL_WEIGHT.getPreferredName(), modelWeight);
+        builder.field(SCORE_MODE.getPreferredName(), scoreMode.name().toLowerCase(Locale.ROOT));
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    /**
+     * First pass: registers an async model load and returns a builder holding a supplier
+     * for the result. Later pass: once the supplier yields a model, returns a builder
+     * holding that model. A builder that already has a model is returned unchanged.
+     */
+    @Override
+    public RescorerBuilder<InferenceRescorerBuilder> rewrite(QueryRewriteContext ctx) {
+        assert modelId != null;
+
+        if (modelSupplier != null) {
+            LocalModel m = modelSupplier.get();
+            if (m == null) {
+                // still loading
+                return this;
+            } else {
+                return copyScoringSettings(
+                    new InferenceRescorerBuilder(modelId, modelLoadingService, inferenceConfig, fieldMap, m));
+            }
+        } else if (model == null) {
+
+            SetOnce<LocalModel> modelHolder = new SetOnce<>();
+
+            ctx.registerAsyncAction(((client, actionListener) ->
+                modelLoadingService.get().getModel(modelId, ActionListener.wrap(
+                    m -> {
+                        modelHolder.set(m);
+                        actionListener.onResponse(null);
+                    },
+                    actionListener::onFailure))
+            ));
+
+            return copyScoringSettings(
+                new InferenceRescorerBuilder(modelId, modelLoadingService, inferenceConfig, fieldMap, modelHolder::get));
+        }
+        return this;
+    }
+
+    /** Copies the settings that are not constructor arguments onto a rewritten builder. */
+    private InferenceRescorerBuilder copyScoringSettings(InferenceRescorerBuilder target) {
+        target.setQueryWeight(queryWeight);
+        target.setModelWeight(modelWeight);
+        target.setScoreMode(scoreMode);
+        // The window size lives in the base class and would otherwise be lost on rewrite
+        if (windowSize != null) {
+            target.windowSize(windowSize);
+        }
+        return target;
+    }
+
+    @Override
+    protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) {
+        assert model != null;
+
+        InferenceConfigUpdate update;
+        if (inferenceConfig == null) {
+            update = new EmptyConfigUpdate();
+        } else if (inferenceConfig instanceof RegressionConfig) {
+            update = RegressionConfigUpdate.fromConfig((RegressionConfig) inferenceConfig);
+        } else {
+            // TODO better message
+            throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
+                inferenceConfig.getName(), RegressionConfig.NAME.getPreferredName());
+        }
+
+        return new RescoreContext(windowSize, new InferenceRescorer(model, update, fieldMap, scoreModeSettings()));
+    }
+
+    /** Immutable bundle of the settings used to combine query and model scores. */
+    static class ScoreModeSettings {
+        final float queryWeight;
+        final float modelWeight;
+        final QueryRescoreMode scoreMode;
+
+        ScoreModeSettings(float queryWeight, float modelWeight, QueryRescoreMode scoreMode) {
+            this.queryWeight = queryWeight;
+            this.modelWeight = modelWeight;
+            this.scoreMode = scoreMode;
+        }
+    }
+
+    private ScoreModeSettings scoreModeSettings() {
+        return new ScoreModeSettings(this.queryWeight, this.modelWeight, this.scoreMode);
+    }
+
+    @Override
+    public final int hashCode() {
+        return Objects.hash(windowSize, modelId, inferenceConfig, fieldMap, queryWeight, modelWeight, scoreMode);
+    }
+
+    @Override
+    public final boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        InferenceRescorerBuilder other = (InferenceRescorerBuilder) obj;
+        return Objects.equals(windowSize, other.windowSize)
+            && Objects.equals(modelId, other.modelId)
+            && Objects.equals(inferenceConfig, other.inferenceConfig)
+            && Objects.equals(fieldMap, other.fieldMap)
+            && Objects.equals(queryWeight, other.queryWeight)
+            && Objects.equals(modelWeight, other.modelWeight)
+            && Objects.equals(scoreMode, other.scoreMode);
+    }
+}
diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 8753db3e878ea..033d50c667dbb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -125,7 +125,7 @@ public void testGetCachedModels() throws Exception { String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -138,7 +138,7 @@ public void testGetCachedModels() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -190,7 +190,7 @@ public void testMaxCachedLimitReached() throws Exception { for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -213,7 +213,7 @@ public boolean matches(final Object o) { // Load model 3, should invalidate 1 and 2 for(int i = 0; i < 10; i++) { - PlainActionFuture future3 = new PlainActionFuture<>(); + PlainActionFuture future3 = new 
PlainActionFuture<>(); modelLoadingService.getModel(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } @@ -234,7 +234,7 @@ public boolean matches(final Object o) { // Load model 1, should invalidate 3 for(int i = 0; i < 10; i++) { - PlainActionFuture future1 = new PlainActionFuture<>(); + PlainActionFuture future1 = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } @@ -248,7 +248,7 @@ public boolean matches(final Object o) { // Load model 2 for(int i = 0; i < 10; i++) { - PlainActionFuture future2 = new PlainActionFuture<>(); + PlainActionFuture future2 = new PlainActionFuture<>(); modelLoadingService.getModel(model2, future2); assertThat(future2.get(), is(not(nullValue()))); } @@ -259,7 +259,7 @@ public boolean matches(final Object o) { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -286,7 +286,7 @@ public void testWhenCacheEnabledButNotIngestNode() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); for(int i = 0; i < 10; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future); assertThat(future.get(), is(not(nullValue()))); } @@ -309,7 +309,7 @@ public void testGetCachedMissingModel() throws Exception { "test-node"); modelLoadingService.clusterChanged(ingestChangedEvent(model)); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { @@ -336,7 +336,7 @@ public void testGetMissingModel() { Settings.EMPTY, "test-node"); - 
PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { future.get(); @@ -360,7 +360,7 @@ public void testGetModelEagerly() throws Exception { "test-node"); for(int i = 0; i < 3; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -458,7 +458,7 @@ private void onFailure(Exception e) { fail(e.getMessage()); } - ActionListener actionListener() { + ActionListener actionListener() { return ActionListener.wrap(this::onModelLoaded, this::onFailure); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java new file mode 100644 index 0000000000000..fd5ce7de75082 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.inference.search; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.rescore.QueryRescoreMode; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +/** Tests for {@code InferenceRescorerBuilder} built on {@code AbstractSerializingTestCase}; this class supplies the XContent parser, the stream reader, random instances and the registries the base class asks for. */ public class InferenceRescorerBuilderTests extends AbstractSerializingTestCase { + + /** Parses a builder from {@code parser}, passing a mocked {@code ModelLoadingService} wrapped in a {@code SetOnce}. */ @Override + protected InferenceRescorerBuilder doParseInstance(XContentParser parser) { + return InferenceRescorerBuilder.fromXContent(parser, new SetOnce<>(mock(ModelLoadingService.class))); + } + + /** Returns a reader that constructs a builder from the stream, again with a mocked loading service. */ @Override + protected Writeable.Reader instanceReader() { + return in -> new InferenceRescorerBuilder(in, new SetOnce<>(mock(ModelLoadingService.class))); + } + + /** Creates a random builder. The inference config stays null unless a random boolean is true, in which case it is either a random classification config or a random regression config. Query weight, model weight and score mode are set together, and only when a further random boolean is true. */ @Override + protected InferenceRescorerBuilder createTestInstance() { + InferenceConfig config = null; + + if (randomBoolean()) { + if (randomBoolean()) { + config = ClassificationConfigTests.randomClassificationConfig(); + } else { + config = RegressionConfigTests.randomRegressionConfig(); + } + } + + InferenceRescorerBuilder builder = new InferenceRescorerBuilder( + randomAlphaOfLength(8), + new SetOnce<>(mock(ModelLoadingService.class)), + config, randomMap()); + + if (randomBoolean()) { + builder.setQueryWeight((float) randomDoubleBetween(0.0, 1.0, true));
+ builder.setModelWeight((float) randomDoubleBetween(0.0, 2.0, true)); + builder.setScoreMode(randomFrom(QueryRescoreMode.values())); + } + + return builder; + } + + /** Builds a field map with {@code randomIntBetween(0, 6)} entries; keys are "field" followed by the entry index and values come from {@code randomAlphaOfLength(5)}. The map may be empty. */ private Map randomMap() { + int numEntries = randomIntBetween(0, 6); + Map result = new HashMap<>(); + for (int i = 0; i < numEntries; i++) { + result.put("field" + i, randomAlphaOfLength(5)); + } + + return result; + } + + /** Supplies a registry built from the named XContent parsers of {@code MlInferenceNamedXContentProvider}. */ @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } + + /** Supplies a registry built from the named writeables of {@code MlInferenceNamedXContentProvider}. */ @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml new file mode 100644 index 0000000000000..86ac893d62b0d --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml @@ -0,0 +1,106 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e.
the test setup superuser + ml.put_trained_model: + model_id: a-complex-regression-model + body: > + { + "description": "super complex model for tests", + "input": {"field_names": ["decider"]}, + "inference_config": {"regression": {}}, + "definition": { + "trained_model": { + "ensemble": { + "feature_names": [], + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": [ + "decider" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 38, + "decision_type": "lte", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 5.0 + }, + { + "node_index": 2, + "leaf_value": 2.0 + } + ], + "target_type": "regression" + } + } + ] + } + } + } + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + indices.create: + index: store + body: + mappings: + properties: + goods: + type: text + size: + type: double + + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + Content-Type: application/json + bulk: + index: store + refresh: true + body: + - '{ "index": {} }' + - '{ "goods": "television", "size": 32.0 }' + - '{ "index": {} }' + - '{ "goods": "VCR", "size": 0 }' + - '{ "index": {} }' + - '{ "goods": "widescreen television", "size": 40.0 }' + +--- +"Test rescore": + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e.
the test setup superuser + Content-Type: application/json + search: + index: store + body: | + { + "query": { "term" : { "goods": {"value": "television"} } }, + "rescore": { + "ml_rescore" : { + "model_id": "a-complex-regression-model", + "field_mappings": {"size": "decider"}, + "inference_config": { "regression": {} }, + "model_weight": 2.0, + "query_weight": 0.0 + } + } + } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._score: 4.0 } # expected scores: "size" is mapped onto the model feature "decider"; size 32.0 is <= threshold 38 so the tree yields leaf 5.0, size 40.0 yields leaf 2.0; with model_weight 2.0 and query_weight 0.0 that gives 10.0 and 4.0 (presumably the default score mode adds the weighted query and model scores - confirm against InferenceRescorer)