From 2ed40fea2272ac275ba391f2c27105974e31aaac Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 28 Jan 2020 16:32:13 +0000 Subject: [PATCH 01/16] Rescore --- .../xpack/ml/MachineLearning.java | 23 +- .../inference/loadingservice/LocalModel.java | 14 +- .../search/InferenceFetchSubPhase.java | 48 ++++ .../search/InferenceQueryBuilder.java | 68 +++++ .../search/InferenceRescorerBuilder.java | 243 ++++++++++++++++++ .../search/InferenceRescorerBuilderTests.java | 69 +++++ 6 files changed, 461 insertions(+), 4 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 56c9623dc439b..aed1dd933929b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -53,11 +53,13 @@ import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.script.ScriptService; +import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; @@ -229,6 +231,9 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.inference.search.InferenceFetchSubPhase; +import org.elasticsearch.xpack.ml.inference.search.InferenceQueryBuilder; +import org.elasticsearch.xpack.ml.inference.search.InferenceRescorerBuilder; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -338,7 +343,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -362,6 +367,22 @@ protected Setting roleSetting() { }; + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return Collections.singletonList(new InferenceFetchSubPhase()); + } + + @Override + public List> getRescorers() { + return Collections.singletonList( + new RescorerSpec<>(InferenceRescorerBuilder.NAME, InferenceRescorerBuilder::new, InferenceRescorerBuilder::fromXContent)); + } + + @Override + public List> getQueries() { + return Collections.singletonList( + new QuerySpec<>(InferenceQueryBuilder.NAME, InferenceQueryBuilder::new, InferenceQueryBuilder::fromXContent)); + } + @Override public Map getProcessors(Processor.Parameters parameters) { if (this.enabled == false) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index e7da7a36184f8..36076e0e257bb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -8,15 +8,15 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; @@ -133,4 +133,12 @@ public void infer(Map fields, InferenceConfigUpdate update, Acti } } + public InferenceResults infer(Map fields, InferenceConfig config) { + if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { + return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)); + } else { + return trainedModelDefinition.infer(fields, config); + } + } + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java new file mode 100644 index 0000000000000..b3606e1fded89 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class InferenceFetchSubPhase implements FetchSubPhase { + + private static Logger logger = LogManager.getLogger(InferenceFetchSubPhase.class); + + public InferenceFetchSubPhase() { + logger.info("creating InferenceFetchSubPhase"); + } + + public void hitExecute(SearchContext context, HitContext hitContext) throws IOException { + logger.info("hitcontex"); + + Map fields = new HashMap<>(hitContext.hit().getFields()); + fields.put("Hello", new DocumentField("teddy", List.of("ruxpin"))); + hitContext.hit().fields(fields); + } + + public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException { + // do something to a search hit + + logger.info("modifying hits"); + + for (SearchHit hit : hits) { + Map fields = new HashMap<>(hit.getFields()); + fields.put("chocolate", new DocumentField("foo", List.of("bar"))); + hit.fields(fields); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java new file mode 100644 index 0000000000000..3c1f1f0f79ff3 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.search; + +import org.apache.lucene.search.Query; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryShardContext; + +import java.io.IOException; +import java.util.Objects; + +public class InferenceQueryBuilder extends AbstractQueryBuilder { + + public static final String NAME = "ml_magic"; + + private final String modelId; + + public static InferenceQueryBuilder fromXContent(XContentParser parser) throws IOException { + return null; + } + + public InferenceQueryBuilder(String modelId) { + this.modelId = modelId; + } + + public InferenceQueryBuilder(StreamInput in) throws IOException { + modelId = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(modelId); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + return Queries.newMatchAllQuery(); + } + + @Override + protected boolean doEquals(InferenceQueryBuilder other) { + return Objects.equals(this.modelId, other.modelId); + } + + @Override + protected int doHashCode() { + return Objects.hash(modelId); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java new file mode 100644 index 0000000000000..17450bc4e9ca6 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java @@ -0,0 +1,243 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.search.rescore.RescoreContext; +import org.elasticsearch.search.rescore.Rescorer; +import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class InferenceRescorerBuilder extends RescorerBuilder { + + public static final String NAME = "ml_rescore"; + + private static final Logger logger = LogManager.getLogger(InferenceRescorerBuilder.class); + + public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config"); + public static final ParseField FIELD_MAPPINGS = new ParseField("field_mappings"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + args -> new InferenceRescorerBuilder((String) args[0], (List) args[1], (Map) args[2])); + + static { + PARSER.declareString(constructorArg(), MODEL_ID); + PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); + PARSER.declareField(optionalConstructorArg(), (p, c) -> p.mapStrings(), FIELD_MAPPINGS, ObjectParser.ValueType.OBJECT); + } + + public static InferenceRescorerBuilder fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String modelId; + private final InferenceConfig inferenceConfig; + private final Map fieldMap; + + private LocalModel model; + private Supplier modelSupplier; + + private InferenceRescorerBuilder(String modelId, @Nullable List config, @Nullable Map fieldMap) { + this.modelId = modelId; + if (config != null) { + assert config.size() == 1; + this.inferenceConfig = config.get(0); + } else { + this.inferenceConfig = null; + } + this.fieldMap = fieldMap; + } + + InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap) { + this.modelId = modelId; + this.inferenceConfig = config; + this.fieldMap = fieldMap; + } + + InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, + Supplier modelSupplier) { + this(modelId, config, fieldMap); + this.modelSupplier = modelSupplier; + } + + InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, + LocalModel model) { + this(modelId, config, fieldMap); + this.model = Objects.requireNonNull(model); + } + + public InferenceRescorerBuilder(StreamInput in) throws IOException { + super(in); + modelId = in.readString(); + inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class); + boolean readMap = in.readBoolean(); + if (readMap) { + fieldMap = in.readMap(StreamInput::readString, StreamInput::readString); + } else { + fieldMap = null; + } + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + if (modelSupplier != null) { + throw new IllegalStateException("can't serialize model supplier. Missing a rewriteAndFetch?"); + } + + out.writeString(modelId); + out.writeOptionalNamedWriteable(inferenceConfig); + boolean fieldMapPresent = fieldMap != null; + out.writeBoolean(fieldMapPresent); + if (fieldMapPresent) { + out.writeMap(fieldMap, StreamOutput::writeString, StreamOutput::writeString); + } + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID.getPreferredName(), modelId); + if (inferenceConfig != null) { + builder.startObject(INFERENCE_CONFIG.getPreferredName()); + builder.field(inferenceConfig.getName(), inferenceConfig); + builder.endObject(); + } + if (fieldMap != null) { + builder.field(FIELD_MAPPINGS.getPreferredName(), fieldMap); + } + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public RescorerBuilder rewrite(QueryRewriteContext ctx) { + + assert modelId != null; + + if (model != null) { + return this; + } else if (modelSupplier != null) { + if (modelSupplier.get() == null) { + return this; + } else { + return new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelSupplier.get()); + } + } else { + SetOnce modelHolder = new SetOnce<>(); + + ctx.registerAsyncAction(((client, actionListener) -> { + TrainedModelProvider modelProvider = new TrainedModelProvider(client, ctx.getXContentRegistry()); + modelProvider.getTrainedModel(modelId, true, ActionListener.wrap( + trainedModel -> { + LocalModel model = new LocalModel(modelId, + trainedModel.ensureParsedDefinition(ctx.getXContentRegistry()).getModelDefinition(), + trainedModel.getInput()); + modelHolder.set(model); + actionListener.onResponse(null); + }, + actionListener::onFailure + )); + })); + + return new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelHolder::get); + } + } + + @Override + protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) throws IOException { + LocalModel m = (model != null) ? model : modelSupplier.get(); + assert m != null; + return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap)); + } + + @Override + public final int hashCode() { + return Objects.hash(windowSize, modelId, inferenceConfig, fieldMap); + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + InferenceRescorerBuilder other = (InferenceRescorerBuilder) obj; + return Objects.equals(windowSize, other.windowSize) && + Objects.equals(modelId, other.modelId) && + Objects.equals(inferenceConfig, other.inferenceConfig) && + Objects.equals(fieldMap, other.fieldMap); + } + + + private static class InferenceRescorer implements Rescorer { + + private final LocalModel model; + private final InferenceConfig inferenceConfig; + private final Map fieldMap; + + + public InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, Map fieldMap) { + this.model = model; + this.inferenceConfig = inferenceConfig; + this.fieldMap = fieldMap; + } + + @Override + public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) { + + Map doc = buildDoc(fieldMap); + InferenceResults results = model.infer(doc, inferenceConfig); + + return topDocs; + } + + @Override + public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, Explanation sourceExplanation) { + return Explanation.match(1.0, "becuase"); + } + + private Map buildDoc(Map fieldMap) { + return Collections.emptyMap(); + } + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java new file mode 100644 index 0000000000000..9fbdb150554af --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.search; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; + +import java.util.HashMap; +import java.util.Map; + +public class InferenceRescorerBuilderTests extends AbstractSerializingTestCase { + + @Override + protected InferenceRescorerBuilder doParseInstance(XContentParser parser) { + return InferenceRescorerBuilder.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return InferenceRescorerBuilder::new; + } + + @Override + protected InferenceRescorerBuilder createTestInstance() { + InferenceConfig config = null; + + if (randomBoolean()) { + if (randomBoolean()) { + config = ClassificationConfigTests.randomClassificationConfig(); + } else { + config = RegressionConfigTests.randomRegressionConfig(); + } + } + + return new InferenceRescorerBuilder(randomAlphaOfLength(8), config, randomMap()); + } + + private Map randomMap() { + int numEntries = randomIntBetween(0, 6); + Map result = new HashMap<>(); + for (int i=0; i Date: Tue, 28 Jan 2020 16:39:18 +0000 Subject: [PATCH 02/16] WIP --- .../main/java/org/elasticsearch/xpack/ml/MachineLearning.java | 3 ++- .../xpack/ml/inference/loadingservice/LocalModel.java | 4 ++++ .../xpack/ml/inference/search/InferenceRescorerBuilder.java | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index aed1dd933929b..e533b6251ec9e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -368,7 +368,8 @@ protected Setting roleSetting() { }; public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return Collections.singletonList(new InferenceFetchSubPhase()); +// return Collections.singletonList(new InferenceFetchSubPhase()); + return Collections.emptyList(); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 36076e0e257bb..f902efc787e94 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -68,6 +68,10 @@ public String getModelId() { return modelId; } + public Set getFieldNames() { + return fieldNames; + } + @Override public InferenceStats getLatestStatsAndReset() { return statsAccumulator.currentStatsAndReset(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java index 17450bc4e9ca6..8a73d239a52b7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java @@ -219,11 +219,14 @@ public InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, Map< this.model = model; this.inferenceConfig = inferenceConfig; this.fieldMap = fieldMap; + String foo = "\.".split() } @Override public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) { + model. + Map doc = buildDoc(fieldMap); InferenceResults results = model.infer(doc, inferenceConfig); From dc262551db61dbda1878f5fd97260873855c7e55 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 30 Jan 2020 10:55:41 +0000 Subject: [PATCH 03/16] Something working --- .../org/elasticsearch/client/TransformIT.java | 18 ---- .../xpack/ml/MachineLearning.java | 1 - .../search/InferenceRescorerBuilder.java | 90 +++++++++++++++---- .../search/InferenceRescorerBuilderTests.java | 1 - 4 files changed, 74 insertions(+), 36 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java index 94341c41685f0..99afbe379d231 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java @@ -163,24 +163,6 @@ public void cleanUpTransformsAndLogAudits() throws Exception { transformsToClean = new ArrayList<>(); waitForPendingTasks(adminClient()); - - // using '*' to make this lenient and do not fail if the audit index does not exist - SearchRequest searchRequest = new SearchRequest(".transform-notifications-*"); - searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(100).sort("timestamp", SortOrder.ASC)); - - for (SearchHit hit : searchAll(searchRequest)) { - Map source = hit.getSourceAsMap(); - String level = (String) source.getOrDefault("level", "info"); - logger.log( - Level.getLevel(level.toUpperCase(Locale.ROOT)), - "Transform audit: [{}] [{}] [{}] [{}]", - Instant.ofEpochMilli((long) source.getOrDefault("timestamp", 0)), - source.getOrDefault("transform_id", "n/a"), - source.getOrDefault("message", "n/a"), - source.getOrDefault("node_name", "n/a") - ); - } - } public void testCreateDelete() throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index e533b6251ec9e..c32086ebf8ad5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -231,7 +231,6 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import org.elasticsearch.xpack.ml.inference.search.InferenceFetchSubPhase; import org.elasticsearch.xpack.ml.inference.search.InferenceQueryBuilder; import org.elasticsearch.xpack.ml.inference.search.InferenceRescorerBuilder; import org.elasticsearch.xpack.ml.job.JobManager; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java index 8a73d239a52b7..76f19ba25fa6d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java @@ -8,8 +8,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; @@ -21,21 +27,30 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.fielddata.FieldData; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.Rescorer; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.io.IOException; -import java.util.Collections; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.Supplier; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; @@ -47,6 +62,8 @@ public class InferenceRescorerBuilder extends RescorerBuilder rewrite(QueryRewriteContext ctx } @Override - protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) throws IOException { + protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) { LocalModel m = (model != null) ? model : modelSupplier.get(); assert m != null; return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap)); @@ -215,31 +232,72 @@ private static class InferenceRescorer implements Rescorer { private final Map fieldMap; - public InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, Map fieldMap) { + private InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, Map fieldMap) { this.model = model; this.inferenceConfig = inferenceConfig; this.fieldMap = fieldMap; - String foo = "\.".split() + + assert inferenceConfig instanceof RegressionConfig; } @Override - public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) { - - model. + public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException { + + // Copy ScoreDoc[] and sort by ascending docID: + ScoreDoc[] sortedHits = topDocs.scoreDocs.clone(); + Comparator docIdComparator = Comparator.comparingInt(sd -> sd.doc); + Arrays.sort(sortedHits, docIdComparator); + + // field map is fieldname in doc -> fieldname expected by model + Set fieldsToRead = new HashSet<>(model.getFieldNames()); + for (Map.Entry entry : fieldMap.entrySet()) { + if (fieldsToRead.contains(entry.getValue())) { + // replace the model fieldname with the doc fieldname + fieldsToRead.remove(entry.getValue()); + fieldsToRead.add(entry.getKey()); + } + } - Map doc = buildDoc(fieldMap); - InferenceResults results = model.infer(doc, inferenceConfig); + List leaves = searcher.getIndexReader().getContext().leaves(); + Map fields = new HashMap<>(); + for (int i=0; i buildDoc(Map fieldMap) { - return Collections.emptyMap(); + public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, + Explanation sourceExplanation) { + return Explanation.match(1.0, "because"); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java index 9fbdb150554af..706ed0a8c27bd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.inference.search; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.NamedXContentRegistry; From 148e59b2367fcd6ff0dcbb5cd2bfc6b80b1d2ffa Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 6 Feb 2020 12:21:05 +0000 Subject: [PATCH 04/16] YML tests for rescore --- .../rest-api-spec/test/ml/rescore.yml | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml new file mode 100644 index 0000000000000..f1db571e766a5 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml @@ -0,0 +1,100 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: a-complex-regression-model + body: > + { + "description": "super complex model for tests", + "input": {"field_names": ["decider"]}, + "definition": { + "trained_model": { + "ensemble": { + "feature_names": [], + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": [ + "decider" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 38, + "decision_type": "lte", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 5.0 + }, + { + "node_index": 2, + "leaf_value": 2.0 + } + ], + "target_type": "regression" + } + } + ] + } + } + } + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + indices.create: + index: store + body: + mappings: + properties: + goods: + type: text + size: + type: double + + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + Content-Type: application/json + bulk: + index: store + refresh: true + body: | + { "index": {} } + { "goods": "television", "size": 32.0 } + { "index": {} } + { "goods": "VCR", "size": 0 } + { "index": {} } + { "goods": "widescreen television", "size": 40.0 } + +--- +"Test rescore": + + - do: + search: + index: store + body: | + { + "query": { "term" : { "goods": {"value": "television"} } }, + "rescore": { + "ml_rescore" : { + "model_id": "a-complex-regression-model", + "field_mappings": {"size": "decider"}, + "inference_config": { "regression": {} } + } + } + } + - match: { hits.hits.0._score: 5.0 } + - match: { hits.hits.1._score: 2.0 } From a00bd755c10942af68f91c291e459b4d1a6218b3 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 6 Feb 2020 17:45:44 +0000 Subject: [PATCH 05/16] Rescore Mode --- .../inference/search/InferenceRescorer.java | 123 ++++++++++++ .../search/InferenceRescorerBuilder.java | 190 +++++++----------- .../search/InferenceRescorerBuilderTests.java | 10 +- .../rest-api-spec/test/ml/rescore.yml | 8 +- 4 files changed, 213 insertions(+), 118 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java new file mode 100644 index 0000000000000..332980a9495e0 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.search; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.index.fielddata.FieldData; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; +import org.elasticsearch.search.rescore.RescoreContext; +import org.elasticsearch.search.rescore.Rescorer; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class InferenceRescorer implements Rescorer { + + private static final Logger logger = LogManager.getLogger(InferenceRescorer.class); + + private final LocalModel model; + private final InferenceConfig inferenceConfig; + private final Map fieldMap; + private final InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings; + + + InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, + Map fieldMap, InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings) { + this.model = model; + this.inferenceConfig = inferenceConfig; + this.fieldMap = fieldMap; + this.scoreModeSettings = scoreModeSettings; + + assert inferenceConfig instanceof RegressionConfig; + } + + @Override + public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException { + + // Copy ScoreDoc[] and sort by ascending docID: + ScoreDoc[] sortedHits = topDocs.scoreDocs.clone(); + Comparator docIdComparator = Comparator.comparingInt(sd -> sd.doc); + Arrays.sort(sortedHits, docIdComparator); + + // field map is fieldname in doc -> fieldname expected by model + Set fieldsToRead = new HashSet<>(model.getFieldNames()); + for (Map.Entry entry : fieldMap.entrySet()) { + if (fieldsToRead.contains(entry.getValue())) { + // replace the model fieldname with the doc fieldname + fieldsToRead.remove(entry.getValue()); + fieldsToRead.add(entry.getKey()); + } + } + + List leaves = searcher.getIndexReader().getContext().leaves(); + Map fields = new HashMap<>(); + for (int i=0; i PARSER = new ConstructingObjectParser<>(NAME, @@ -76,6 +59,9 @@ public class InferenceRescorerBuilder extends RescorerBuilder p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); PARSER.declareField(optionalConstructorArg(), (p, c) -> p.mapStrings(), FIELD_MAPPINGS, ObjectParser.ValueType.OBJECT); + PARSER.declareFloat(InferenceRescorerBuilder::setQueryWeight, QUERY_WEIGHT); + PARSER.declareFloat(InferenceRescorerBuilder::setModelWeight, MODEL_WEIGHT); + PARSER.declareString((builder, mode) -> builder.setScoreMode(QueryRescoreMode.fromString(mode)), SCORE_MODE); } public static InferenceRescorerBuilder fromXContent(XContentParser parser) { @@ -89,6 +75,10 @@ public static InferenceRescorerBuilder fromXContent(XContentParser parser) { private LocalModel model; private Supplier modelSupplier; + private float queryWeight = DEFAULT_QUERY_WEIGHT; + private float modelWeight = DEFAULT_MODEL_WEIGHT; + private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE; + private InferenceRescorerBuilder(String modelId, @Nullable List config, @Nullable Map fieldMap) { this.modelId = modelId; if (config != null) { @@ -106,13 +96,13 @@ private InferenceRescorerBuilder(String modelId, @Nullable List this.fieldMap = fieldMap; } - InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, + private InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, Supplier modelSupplier) { this(modelId, config, fieldMap); this.modelSupplier = modelSupplier; } - InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, + private InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, LocalModel model) { this(modelId, config, fieldMap); this.model = Objects.requireNonNull(model); @@ -128,6 +118,21 @@ public InferenceRescorerBuilder(StreamInput in) throws IOException { } else { fieldMap = null; } + queryWeight = in.readFloat(); + modelWeight = in.readFloat(); + scoreMode = QueryRescoreMode.readFromStream(in); + } + + void setQueryWeight(float queryWeight) { + this.queryWeight = queryWeight; + } + + void setModelWeight(float modelWeight) { + this.modelWeight = modelWeight; + } + + void setScoreMode(QueryRescoreMode scoreMode) { + this.scoreMode = scoreMode; } @Override @@ -143,6 +148,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (fieldMapPresent) { out.writeMap(fieldMap, StreamOutput::writeString, StreamOutput::writeString); } + out.writeFloat(queryWeight); + out.writeFloat(modelWeight); + scoreMode.writeTo(out); } @Override @@ -156,6 +164,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep if (fieldMap != null) { builder.field(FIELD_MAPPINGS.getPreferredName(), fieldMap); } + builder.field(QUERY_WEIGHT.getPreferredName(), queryWeight); + builder.field(MODEL_WEIGHT.getPreferredName(), modelWeight); + builder.field(SCORE_MODE.getPreferredName(), scoreMode.name().toLowerCase(Locale.ROOT)); } @Override @@ -174,7 +185,8 @@ public RescorerBuilder rewrite(QueryRewriteContext ctx if (modelSupplier.get() == null) { return this; } else { - return new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelSupplier.get()); + return copyScoringSettings(new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelSupplier.get())); + } } else { SetOnce modelHolder = new SetOnce<>(); @@ -193,20 +205,45 @@ public RescorerBuilder rewrite(QueryRewriteContext ctx )); })); - return new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelHolder::get); + return copyScoringSettings(new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelHolder::get)); } } + private InferenceRescorerBuilder copyScoringSettings(InferenceRescorerBuilder target) { + target.setQueryWeight(queryWeight); + target.setModelWeight(modelWeight); + target.setScoreMode(scoreMode); + return target; + } + @Override protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) { LocalModel m = (model != null) ? model : modelSupplier.get(); assert m != null; - return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap)); + + return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap, + scoreModeSettings())); + } + + class ScoreModeSettings { + float queryWeight; + float modelWeight; + QueryRescoreMode scoreMode; + + ScoreModeSettings(float queryWeight, float modelWeight, QueryRescoreMode scoreMode) { + this.queryWeight = queryWeight; + this.modelWeight = modelWeight; + this.scoreMode = scoreMode; + } + } + + private ScoreModeSettings scoreModeSettings() { + return new ScoreModeSettings(this.queryWeight, this.modelWeight, this.scoreMode); } @Override public final int hashCode() { - return Objects.hash(windowSize, modelId, inferenceConfig, fieldMap); + return Objects.hash(windowSize, modelId, inferenceConfig, fieldMap, queryWeight, modelWeight, scoreMode); } @Override @@ -221,84 +258,9 @@ public final boolean equals(Object obj) { return Objects.equals(windowSize, other.windowSize) && Objects.equals(modelId, other.modelId) && Objects.equals(inferenceConfig, other.inferenceConfig) && - Objects.equals(fieldMap, other.fieldMap); - } - - - private static class InferenceRescorer implements Rescorer { - - private final LocalModel model; - private final InferenceConfig inferenceConfig; - private final Map fieldMap; - - - private InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, Map fieldMap) { - this.model = model; - this.inferenceConfig = inferenceConfig; - this.fieldMap = fieldMap; - - assert inferenceConfig instanceof RegressionConfig; - } - - @Override - public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException { - - // Copy ScoreDoc[] and sort by ascending docID: - ScoreDoc[] sortedHits = topDocs.scoreDocs.clone(); - Comparator docIdComparator = Comparator.comparingInt(sd -> sd.doc); - Arrays.sort(sortedHits, docIdComparator); - - // field map is fieldname in doc -> fieldname expected by model - Set fieldsToRead = new HashSet<>(model.getFieldNames()); - for (Map.Entry entry : fieldMap.entrySet()) { - if (fieldsToRead.contains(entry.getValue())) { - // replace the model fieldname with the doc fieldname - fieldsToRead.remove(entry.getValue()); - fieldsToRead.add(entry.getKey()); - } - } - - List leaves = searcher.getIndexReader().getContext().leaves(); - Map fields = new HashMap<>(); - for (int i=0; i randomMap() { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml index f1db571e766a5..11543c6b76e3d 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml @@ -92,9 +92,11 @@ setup: "ml_rescore" : { "model_id": "a-complex-regression-model", "field_mappings": {"size": "decider"}, - "inference_config": { "regression": {} } + "inference_config": { "regression": {} }, + "model_weight": 2.0, + "query_weight": 0.0 } } } - - match: { hits.hits.0._score: 5.0 } - - match: { hits.hits.1._score: 2.0 } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._score: 4.0 } From b071af252fa33fcea349c329d42d93a1242fea53 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Fri, 7 Feb 2020 13:16:35 +0000 Subject: [PATCH 06/16] Rework walking the leaves --- .../inference/search/InferenceRescorer.java | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java index 332980a9495e0..ee21876d08862 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -77,6 +76,53 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r List leaves = searcher.getIndexReader().getContext().leaves(); Map fields = new HashMap<>(); + + int currentReader = 0; + int endDoc = 0; + LeafReaderContext readerContext = null; + + for (int hitIndex=0; hitIndex= endDoc) { + readerContext = leaves.get(currentReader); + currentReader++; + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + + for (String field : fieldsToRead) { + SortedNumericDocValues docValuesIter = DocValues.getSortedNumeric(readerContext.reader(), field); + SortedNumericDoubleValues doubles = FieldData.sortableLongBitsToDoubles(docValuesIter); + if (doubles.advanceExact(hit.doc)) { + double val = doubles.nextValue(); + fields.put(fieldMap.getOrDefault(field, field), val); + } else if (docValuesIter.docID() == DocIdSetIterator.NO_MORE_DOCS) { + logger.warn("No more docs for field {}, doc {}", field, hit.doc); + fields.remove(field); + } else { + logger.warn("no value for field {}, doc {}", field, hit.doc); + fields.remove(field); + } + } + + InferenceResults infer = model.infer(fields, inferenceConfig); + if (infer instanceof WarningInferenceResults) { + logger.warn("inference error: " + ((WarningInferenceResults) infer).getWarning()); + // TODO how to propagate this error + } else { + SingleValueInferenceResults regressionResult = (SingleValueInferenceResults) infer; + + float combinedScore = scoreModeSettings.scoreMode.combine( + hit.score * scoreModeSettings.queryWeight, + regressionResult.value().floatValue() * scoreModeSettings.modelWeight); + + sortedHits[hitIndex] = new ScoreDoc(hit.doc, combinedScore); + } + } + + + /* for (int i=0; i Date: Fri, 7 Feb 2020 13:40:45 +0000 Subject: [PATCH 07/16] Clean up --- .../xpack/ml/MachineLearning.java | 13 ---- .../search/InferenceFetchSubPhase.java | 48 ------------- .../search/InferenceQueryBuilder.java | 68 ------------------- .../inference/search/InferenceRescorer.java | 41 ----------- 4 files changed, 170 deletions(-) delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index c32086ebf8ad5..e8cf3201bbef4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -59,7 +59,6 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.script.ScriptService; -import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; @@ -231,7 +230,6 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import org.elasticsearch.xpack.ml.inference.search.InferenceQueryBuilder; import org.elasticsearch.xpack.ml.inference.search.InferenceRescorerBuilder; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; @@ -366,23 +364,12 @@ protected Setting roleSetting() { }; - public List getFetchSubPhases(FetchPhaseConstructionContext context) { -// return Collections.singletonList(new InferenceFetchSubPhase()); - return Collections.emptyList(); - } - @Override public List> getRescorers() { return Collections.singletonList( new RescorerSpec<>(InferenceRescorerBuilder.NAME, InferenceRescorerBuilder::new, InferenceRescorerBuilder::fromXContent)); } - @Override - public List> getQueries() { - return Collections.singletonList( - new QuerySpec<>(InferenceQueryBuilder.NAME, InferenceQueryBuilder::new, InferenceQueryBuilder::fromXContent)); - } - @Override public Map getProcessors(Processor.Parameters parameters) { if (this.enabled == false) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java deleted file mode 100644 index b3606e1fded89..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceFetchSubPhase.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ - -package org.elasticsearch.xpack.ml.inference.search; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.document.DocumentField; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.fetch.FetchSubPhase; -import org.elasticsearch.search.internal.SearchContext; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class InferenceFetchSubPhase implements FetchSubPhase { - - private static Logger logger = LogManager.getLogger(InferenceFetchSubPhase.class); - - public InferenceFetchSubPhase() { - logger.info("creating InferenceFetchSubPhase"); - } - - public void hitExecute(SearchContext context, HitContext hitContext) throws IOException { - logger.info("hitcontex"); - - Map fields = new HashMap<>(hitContext.hit().getFields()); - fields.put("Hello", new DocumentField("teddy", List.of("ruxpin"))); - hitContext.hit().fields(fields); - } - - public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException { - // do something to a search hit - - logger.info("modifying hits"); - - for (SearchHit hit : hits) { - Map fields = new HashMap<>(hit.getFields()); - fields.put("chocolate", new DocumentField("foo", List.of("bar"))); - hit.fields(fields); - } - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java deleted file mode 100644 index 3c1f1f0f79ff3..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceQueryBuilder.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ - -package org.elasticsearch.xpack.ml.inference.search; - -import org.apache.lucene.search.Query; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.AbstractQueryBuilder; -import org.elasticsearch.index.query.QueryShardContext; - -import java.io.IOException; -import java.util.Objects; - -public class InferenceQueryBuilder extends AbstractQueryBuilder { - - public static final String NAME = "ml_magic"; - - private final String modelId; - - public static InferenceQueryBuilder fromXContent(XContentParser parser) throws IOException { - return null; - } - - public InferenceQueryBuilder(String modelId) { - this.modelId = modelId; - } - - public InferenceQueryBuilder(StreamInput in) throws IOException { - modelId = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(modelId); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - - } - - @Override - protected Query doToQuery(QueryShardContext context) throws IOException { - return Queries.newMatchAllQuery(); - } - - @Override - protected boolean doEquals(InferenceQueryBuilder other) { - return Objects.equals(this.modelId, other.modelId); - } - - @Override - protected int doHashCode() { - return Objects.hash(modelId); - } - - @Override - public String getWriteableName() { - return NAME; - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java index ee21876d08862..e99fe557846d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -121,47 +121,6 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r } } - - /* - for (int i=0; i Date: Fri, 7 Feb 2020 14:30:56 +0000 Subject: [PATCH 08/16] Apply spotless formatting --- .../inference/search/InferenceRescorer.java | 21 ++++--- .../search/InferenceRescorerBuilder.java | 62 +++++++++++-------- .../search/InferenceRescorerBuilderTests.java | 6 +- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java index e99fe557846d5..79f68097179d3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -45,9 +45,12 @@ public class InferenceRescorer implements Rescorer { private final Map fieldMap; private final InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings; - - InferenceRescorer(LocalModel model, InferenceConfig inferenceConfig, - Map fieldMap, InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings) { + InferenceRescorer( + LocalModel model, + InferenceConfig inferenceConfig, + Map fieldMap, + InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings + ) { this.model = model; this.inferenceConfig = inferenceConfig; this.fieldMap = fieldMap; @@ -81,7 +84,7 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r int endDoc = 0; LeafReaderContext readerContext = null; - for (int hitIndex=0; hitIndex PARSER = new ConstructingObjectParser<>(NAME, - args -> new InferenceRescorerBuilder((String) args[0], (List) args[1], (Map) args[2])); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + args -> new InferenceRescorerBuilder((String) args[0], (List) args[1], (Map) args[2]) + ); static { PARSER.declareString(constructorArg(), MODEL_ID); - PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); + PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); PARSER.declareField(optionalConstructorArg(), (p, c) -> p.mapStrings(), FIELD_MAPPINGS, ObjectParser.ValueType.OBJECT); PARSER.declareFloat(InferenceRescorerBuilder::setQueryWeight, QUERY_WEIGHT); PARSER.declareFloat(InferenceRescorerBuilder::setModelWeight, MODEL_WEIGHT); - PARSER.declareString((builder, mode) -> builder.setScoreMode(QueryRescoreMode.fromString(mode)), SCORE_MODE); + PARSER.declareString((builder, mode) -> builder.setScoreMode(QueryRescoreMode.fromString(mode)), SCORE_MODE); } public static InferenceRescorerBuilder fromXContent(XContentParser parser) { @@ -96,14 +98,22 @@ private InferenceRescorerBuilder(String modelId, @Nullable List this.fieldMap = fieldMap; } - private InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, - Supplier modelSupplier) { + private InferenceRescorerBuilder( + String modelId, + @Nullable InferenceConfig config, + @Nullable Map fieldMap, + Supplier modelSupplier + ) { this(modelId, config, fieldMap); this.modelSupplier = modelSupplier; } - private InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap, - LocalModel model) { + private InferenceRescorerBuilder( + String modelId, + @Nullable InferenceConfig config, + @Nullable Map fieldMap, + LocalModel model + ) { this(modelId, config, fieldMap); this.model = Objects.requireNonNull(model); } @@ -193,16 +203,15 @@ public RescorerBuilder rewrite(QueryRewriteContext ctx ctx.registerAsyncAction(((client, actionListener) -> { TrainedModelProvider modelProvider = new TrainedModelProvider(client, ctx.getXContentRegistry()); - modelProvider.getTrainedModel(modelId, true, ActionListener.wrap( - trainedModel -> { - LocalModel model = new LocalModel(modelId, - trainedModel.ensureParsedDefinition(ctx.getXContentRegistry()).getModelDefinition(), - trainedModel.getInput()); - modelHolder.set(model); - actionListener.onResponse(null); - }, - actionListener::onFailure - )); + modelProvider.getTrainedModel(modelId, true, ActionListener.wrap(trainedModel -> { + LocalModel model = new LocalModel( + modelId, + trainedModel.ensureParsedDefinition(ctx.getXContentRegistry()).getModelDefinition(), + trainedModel.getInput() + ); + modelHolder.set(model); + actionListener.onResponse(null); + }, actionListener::onFailure)); })); return copyScoringSettings(new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelHolder::get)); @@ -221,8 +230,7 @@ protected RescoreContext innerBuildContext(int windowSize, QueryShardContext con LocalModel m = (model != null) ? model : modelSupplier.get(); assert m != null; - return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap, - scoreModeSettings())); + return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap, scoreModeSettings())); } class ScoreModeSettings { @@ -255,12 +263,12 @@ public final boolean equals(Object obj) { return false; } InferenceRescorerBuilder other = (InferenceRescorerBuilder) obj; - return Objects.equals(windowSize, other.windowSize) && - Objects.equals(modelId, other.modelId) && - Objects.equals(inferenceConfig, other.inferenceConfig) && - Objects.equals(fieldMap, other.fieldMap) && - Objects.equals(queryWeight, other.queryWeight) && - Objects.equals(modelWeight, other.modelWeight) && - Objects.equals(scoreMode, other.scoreMode); + return Objects.equals(windowSize, other.windowSize) + && Objects.equals(modelId, other.modelId) + && Objects.equals(inferenceConfig, other.inferenceConfig) + && Objects.equals(fieldMap, other.fieldMap) + && Objects.equals(queryWeight, other.queryWeight) + && Objects.equals(modelWeight, other.modelWeight) + && Objects.equals(scoreMode, other.scoreMode); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java index 210fe81639376..685779bccb780 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java @@ -46,8 +46,8 @@ protected InferenceRescorerBuilder createTestInstance() { InferenceRescorerBuilder builder = new InferenceRescorerBuilder(randomAlphaOfLength(8), config, randomMap()); if (randomBoolean()) { - builder.setQueryWeight((float)randomDoubleBetween(0.0, 1.0, true)); - builder.setModelWeight((float)randomDoubleBetween(0.0, 2.0, true)); + builder.setQueryWeight((float) randomDoubleBetween(0.0, 1.0, true)); + builder.setModelWeight((float) randomDoubleBetween(0.0, 2.0, true)); builder.setScoreMode(randomFrom(QueryRescoreMode.values())); } @@ -57,7 +57,7 @@ protected InferenceRescorerBuilder createTestInstance() { private Map randomMap() { int numEntries = randomIntBetween(0, 6); Map result = new HashMap<>(); - for (int i=0; i Date: Fri, 7 Feb 2020 15:28:52 +0000 Subject: [PATCH 09/16] Tidy up --- .../org/elasticsearch/client/TransformIT.java | 18 ++++++++++++++++++ .../ml/inference/search/InferenceRescorer.java | 9 ++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java index 99afbe379d231..94341c41685f0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/TransformIT.java @@ -163,6 +163,24 @@ public void cleanUpTransformsAndLogAudits() throws Exception { transformsToClean = new ArrayList<>(); waitForPendingTasks(adminClient()); + + // using '*' to make this lenient and do not fail if the audit index does not exist + SearchRequest searchRequest = new SearchRequest(".transform-notifications-*"); + searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(100).sort("timestamp", SortOrder.ASC)); + + for (SearchHit hit : searchAll(searchRequest)) { + Map source = hit.getSourceAsMap(); + String level = (String) source.getOrDefault("level", "info"); + logger.log( + Level.getLevel(level.toUpperCase(Locale.ROOT)), + "Transform audit: [{}] [{}] [{}] [{}]", + Instant.ofEpochMilli((long) source.getOrDefault("timestamp", 0)), + source.getOrDefault("transform_id", "n/a"), + source.getOrDefault("message", "n/a"), + source.getOrDefault("node_name", "n/a") + ); + } + } public void testCreateDelete() throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java index 79f68097179d3..2feceff8e68e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -16,6 +16,7 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.index.fielddata.FieldData; import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; import org.elasticsearch.search.rescore.RescoreContext; @@ -67,8 +68,8 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r Comparator docIdComparator = Comparator.comparingInt(sd -> sd.doc); Arrays.sort(sortedHits, docIdComparator); - // field map is fieldname in doc -> fieldname expected by model Set fieldsToRead = new HashSet<>(model.getFieldNames()); + // field map is fieldname in doc -> fieldname expected by model for (Map.Entry entry : fieldMap.entrySet()) { if (fieldsToRead.contains(entry.getValue())) { // replace the model fieldname with the doc fieldname @@ -88,6 +89,7 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r ScoreDoc hit = sortedHits[hitIndex]; int docId = hit.doc; + // get the context for this docId while (docId >= endDoc) { readerContext = leaves.get(currentReader); currentReader++; @@ -111,8 +113,9 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r InferenceResults infer = model.infer(fields, inferenceConfig); if (infer instanceof WarningInferenceResults) { - logger.warn("inference error: " + ((WarningInferenceResults) infer).getWarning()); - // TODO how to propagate this error + String message = ((WarningInferenceResults) infer).getWarning(); + logger.warn("inference error: " + message); + throw new ElasticsearchException(message); } else { SingleValueInferenceResults regressionResult = (SingleValueInferenceResults) infer; From 47fbff8ed5ff421e15540f39eced7eb2d1cd5e27 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 1 Jun 2020 21:59:55 +0100 Subject: [PATCH 10/16] Use model loading service --- .../xpack/ml/MachineLearning.java | 6 +- .../inference/search/InferenceRescorer.java | 6 +- .../search/InferenceRescorerBuilder.java | 113 +++++++----------- 3 files changed, 48 insertions(+), 77 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index e8cf3201bbef4..54e7e2da31d38 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -367,7 +367,9 @@ protected Setting roleSetting() { @Override public List> getRescorers() { return Collections.singletonList( - new RescorerSpec<>(InferenceRescorerBuilder.NAME, InferenceRescorerBuilder::new, InferenceRescorerBuilder::fromXContent)); + new RescorerSpec<>(InferenceRescorerBuilder.NAME, + in -> new InferenceRescorerBuilder(in, modelLoadingService), + parser -> InferenceRescorerBuilder.fromXContent(parser, modelLoadingService))); } @Override @@ -458,6 +460,7 @@ public Set getRoles() { private final SetOnce dataFrameAnalyticsAuditor = new SetOnce<>(); private final SetOnce memoryTracker = new SetOnce<>(); private final SetOnce mlUpgradeModeActionFilter = new SetOnce<>(); + private final SetOnce modelLoadingService = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -686,6 +689,7 @@ public Collection createComponents(Client client, ClusterService cluster trainedModelStatsService, settings, clusterService.getNodeName()); + this.modelLoadingService.set(modelLoadingService); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java index 2feceff8e68e7..a519f706527d7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; -import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; +import org.elasticsearch.xpack.ml.inference.loadingservice.Model; import java.io.IOException; import java.util.Arrays; @@ -41,13 +41,13 @@ public class InferenceRescorer implements Rescorer { private static final Logger logger = LogManager.getLogger(InferenceRescorer.class); - private final LocalModel model; + private final Model model; private final InferenceConfig inferenceConfig; private final Map fieldMap; private final InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings; InferenceRescorer( - LocalModel model, + Model model, InferenceConfig inferenceConfig, Map fieldMap, InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java index 8514906cfa8a6..aaadff7028247 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java @@ -22,11 +22,10 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; -import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import java.io.IOException; -import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -52,73 +51,57 @@ public class InferenceRescorerBuilder extends RescorerBuilder PARSER = new ConstructingObjectParser<>( - NAME, - args -> new InferenceRescorerBuilder((String) args[0], (List) args[1], (Map) args[2]) - ); + private static final ConstructingObjectParser> PARSER = + new ConstructingObjectParser<>(NAME, false, + (args, context) -> + new InferenceRescorerBuilder((String) args[0], context, (InferenceConfig) args[1], (Map) args[2]) + ); static { PARSER.declareString(constructorArg(), MODEL_ID); - PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); + PARSER.declareNamedObject(optionalConstructorArg(), (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); PARSER.declareField(optionalConstructorArg(), (p, c) -> p.mapStrings(), FIELD_MAPPINGS, ObjectParser.ValueType.OBJECT); PARSER.declareFloat(InferenceRescorerBuilder::setQueryWeight, QUERY_WEIGHT); PARSER.declareFloat(InferenceRescorerBuilder::setModelWeight, MODEL_WEIGHT); PARSER.declareString((builder, mode) -> builder.setScoreMode(QueryRescoreMode.fromString(mode)), SCORE_MODE); } - public static InferenceRescorerBuilder fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + public static InferenceRescorerBuilder fromXContent(XContentParser parser, SetOnce modelLoadingService) { + return PARSER.apply(parser, modelLoadingService); } private final String modelId; + private final SetOnce modelLoadingService; private final InferenceConfig inferenceConfig; private final Map fieldMap; - private LocalModel model; - private Supplier modelSupplier; + private Model model; private float queryWeight = DEFAULT_QUERY_WEIGHT; private float modelWeight = DEFAULT_MODEL_WEIGHT; private QueryRescoreMode scoreMode = DEFAULT_SCORE_MODE; - private InferenceRescorerBuilder(String modelId, @Nullable List config, @Nullable Map fieldMap) { - this.modelId = modelId; - if (config != null) { - assert config.size() == 1; - this.inferenceConfig = config.get(0); - } else { - this.inferenceConfig = null; - } - this.fieldMap = fieldMap; - } - - InferenceRescorerBuilder(String modelId, @Nullable InferenceConfig config, @Nullable Map fieldMap) { + public InferenceRescorerBuilder(String modelId, + SetOnce modelLoadingService, + InferenceConfig config, + @Nullable Map fieldMap) { this.modelId = modelId; + this.modelLoadingService = modelLoadingService; this.inferenceConfig = config; this.fieldMap = fieldMap; } - private InferenceRescorerBuilder( - String modelId, - @Nullable InferenceConfig config, - @Nullable Map fieldMap, - Supplier modelSupplier + private InferenceRescorerBuilder(String modelId, + SetOnce modelLoadingService, + @Nullable InferenceConfig config, + @Nullable Map fieldMap, + Supplier modelSupplier ) { - this(modelId, config, fieldMap); - this.modelSupplier = modelSupplier; + this(modelId, modelLoadingService, config, fieldMap); + this.model = modelSupplier.get(); } - private InferenceRescorerBuilder( - String modelId, - @Nullable InferenceConfig config, - @Nullable Map fieldMap, - LocalModel model - ) { - this(modelId, config, fieldMap); - this.model = Objects.requireNonNull(model); - } - - public InferenceRescorerBuilder(StreamInput in) throws IOException { + public InferenceRescorerBuilder(StreamInput in, SetOnce modelLoadingService) throws IOException { super(in); modelId = in.readString(); inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class); @@ -131,6 +114,8 @@ public InferenceRescorerBuilder(StreamInput in) throws IOException { queryWeight = in.readFloat(); modelWeight = in.readFloat(); scoreMode = QueryRescoreMode.readFromStream(in); + + this.modelLoadingService = modelLoadingService; } void setQueryWeight(float queryWeight) { @@ -147,10 +132,6 @@ void setScoreMode(QueryRescoreMode scoreMode) { @Override protected void doWriteTo(StreamOutput out) throws IOException { - if (modelSupplier != null) { - throw new IllegalStateException("can't serialize model supplier. Missing a rewriteAndFetch?"); - } - out.writeString(modelId); out.writeOptionalNamedWriteable(inferenceConfig); boolean fieldMapPresent = fieldMap != null; @@ -191,30 +172,17 @@ public RescorerBuilder rewrite(QueryRewriteContext ctx if (model != null) { return this; - } else if (modelSupplier != null) { - if (modelSupplier.get() == null) { - return this; - } else { - return copyScoringSettings(new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelSupplier.get())); - - } } else { - SetOnce modelHolder = new SetOnce<>(); - - ctx.registerAsyncAction(((client, actionListener) -> { - TrainedModelProvider modelProvider = new TrainedModelProvider(client, ctx.getXContentRegistry()); - modelProvider.getTrainedModel(modelId, true, ActionListener.wrap(trainedModel -> { - LocalModel model = new LocalModel( - modelId, - trainedModel.ensureParsedDefinition(ctx.getXContentRegistry()).getModelDefinition(), - trainedModel.getInput() - ); - modelHolder.set(model); - actionListener.onResponse(null); - }, actionListener::onFailure)); - })); - - return copyScoringSettings(new InferenceRescorerBuilder(modelId, inferenceConfig, fieldMap, modelHolder::get)); + SetOnce modelHolder = new SetOnce<>(); + + ctx.registerAsyncAction(((client, actionListener) -> + modelLoadingService.get().getModel(modelId, ActionListener.wrap( + modelHolder::set, + actionListener::onFailure)) + )); + + return copyScoringSettings( + new InferenceRescorerBuilder(modelId, modelLoadingService, inferenceConfig, fieldMap, modelHolder::get)); } } @@ -227,13 +195,12 @@ private InferenceRescorerBuilder copyScoringSettings(InferenceRescorerBuilder ta @Override protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) { - LocalModel m = (model != null) ? model : modelSupplier.get(); - assert m != null; + assert model != null; - return new RescoreContext(windowSize, new InferenceRescorer(m, inferenceConfig, fieldMap, scoreModeSettings())); + return new RescoreContext(windowSize, new InferenceRescorer(model, inferenceConfig, fieldMap, scoreModeSettings())); } - class ScoreModeSettings { + static class ScoreModeSettings { float queryWeight; float modelWeight; QueryRescoreMode scoreMode; From 65e33b7ad40a7b6019a4dc5bc8ae49d45e574e44 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 2 Jun 2020 08:18:30 +0100 Subject: [PATCH 11/16] Use the LocalModel interface instead of Model --- .../TransportInternalInferModelAction.java | 4 +-- .../inference/loadingservice/LocalModel.java | 34 +++++++++---------- .../loadingservice/ModelLoadingService.java | 24 ++++++------- .../ModelLoadingServiceTests.java | 24 ++++++------- 4 files changed, 42 insertions(+), 44 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index d229f4decbee7..698b353aefbc2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -20,7 +20,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; @@ -52,7 +52,7 @@ protected void doExecute(Task task, Request request, ActionListener li Response.Builder responseBuilder = Response.builder(); - ActionListener getModelListener = ActionListener.wrap( + ActionListener getModelListener = ActionListener.wrap( model -> { TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index f902efc787e94..9758b1f8a078b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -32,7 +32,6 @@ public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; - private final String nodeId; private final Set fieldNames; private final Map defaultFieldMap; private final InferenceStats.Accumulator statsAccumulator; @@ -50,7 +49,6 @@ public LocalModel(String modelId, TrainedModelStatsService trainedModelStatsService ) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; - this.nodeId = nodeId; this.fieldNames = new HashSet<>(input.getFieldNames()); this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId); this.trainedModelStatsService = trainedModelStatsService; @@ -103,14 +101,23 @@ void persistStats(boolean flush) { @Override public void infer(Map fields, InferenceConfigUpdate update, ActionListener listener) { + try { + InferenceResults result = infer(fields, update); + listener.onResponse(result); + } catch (Exception e) { + listener.onFailure(e); + } + } + + public InferenceResults infer(Map fields, InferenceConfigUpdate update) { if (update.isSupported(this.inferenceConfig) == false) { - listener.onFailure(ExceptionsHelper.badRequestException( + throw ExceptionsHelper.badRequestException( "Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]", this.modelId, this.inferenceConfig.getName(), - update.getName())); - return; + update.getName()); } + try { statsAccumulator.incInference(); currentInferenceCount.increment(); @@ -123,26 +130,17 @@ public void infer(Map fields, InferenceConfigUpdate update, Acti if (shouldPersistStats) { persistStats(false); } - listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); - return; + return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)); } InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig)); if (shouldPersistStats) { persistStats(false); } - listener.onResponse(inferenceResults); + + return inferenceResults; } catch (Exception e) { statsAccumulator.incFailure(); - listener.onFailure(e); - } - } - - public InferenceResults infer(Map fields, InferenceConfig config) { - if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { - return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)); - } else { - return trainedModelDefinition.infer(fields, config); + throw e; } } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 5053972e50bec..dffa3342585e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -88,7 +88,7 @@ public class ModelLoadingService implements ClusterStateListener { private final TrainedModelStatsService modelStatsService; private final Cache localModelCache; private final Set referencedModels = new HashSet<>(); - private final Map>> loadingListeners = new HashMap<>(); + private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final Set shouldNotAudit; private final ThreadPool threadPool; @@ -140,7 +140,7 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, * @param modelId the model to get * @param modelActionListener the listener to alert when the model has been retrieved. */ - public void getModel(String modelId, ActionListener modelActionListener) { + public void getModel(String modelId, ActionListener modelActionListener) { LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); @@ -178,9 +178,9 @@ public void getModel(String modelId, ActionListener modelActionListener) * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded * Returns false if the model is not loaded or actively being loaded */ - private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { synchronized (loadingListeners) { - Model cachedModel = localModelCache.get(modelId); + LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); return true; @@ -219,7 +219,7 @@ private void loadModel(String modelId) { } private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException { - Queue> listeners; + Queue> listeners; trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : @@ -242,13 +242,13 @@ private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelCo localModelCache.put(modelId, loadedModel); shouldNotAudit.remove(modelId); } // synchronized (loadingListeners) - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onResponse(loadedModel); } } private void handleLoadFailure(String modelId, Exception failure) { - Queue> listeners; + Queue> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); if (listeners == null) { @@ -257,7 +257,7 @@ private void handleLoadFailure(String modelId, Exception failure) { } // synchronized (loadingListeners) // If we failed to load and there were listeners present, that means that this model is referenced by a processor // Alert the listeners to the failure - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onFailure(failure); } } @@ -294,7 +294,7 @@ public void clusterChanged(ClusterChangedEvent event) { return; } // The listeners still waiting for a model and we are canceling the load? - List>>> drainWithFailure = new ArrayList<>(); + List>>> drainWithFailure = new ArrayList<>(); Set referencedModelsBeforeClusterState = null; Set loadingModelBeforeClusterState = null; Set removedModels = null; @@ -337,11 +337,11 @@ public void clusterChanged(ClusterChangedEvent event) { referencedModels); } } - for (Tuple>> modelAndListeners : drainWithFailure) { + for (Tuple>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( "Cancelling load of model [{}] as it is no longer referenced by a pipeline", modelAndListeners.v1()).getFormat(); - for (ActionListener listener : modelAndListeners.v2()) { + for (ActionListener listener : modelAndListeners.v2()) { listener.onFailure(new ElasticsearchException(msg)); } } @@ -430,7 +430,7 @@ private static InferenceConfig inferenceConfigFromTargetType(TargetType targetTy * @param modelId Model Id * @param modelLoadedListener To be notified */ - void addModelLoadedListener(String modelId, ActionListener modelLoadedListener) { + void addModelLoadedListener(String modelId, ActionListener modelLoadedListener) { synchronized (loadingListeners) { loadingListeners.compute(modelId, (modelKey, listenerQueue) -> { if (listenerQueue == null) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 8753db3e878ea..033d50c667dbb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -125,7 +125,7 @@ public void testGetCachedModels() throws Exception { String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -138,7 +138,7 @@ public void testGetCachedModels() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -190,7 +190,7 @@ public void testMaxCachedLimitReached() throws Exception { for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -213,7 +213,7 @@ public boolean matches(final Object o) { // Load model 3, should invalidate 1 and 2 for(int i = 0; i < 10; i++) { - PlainActionFuture future3 = new PlainActionFuture<>(); + PlainActionFuture future3 = new PlainActionFuture<>(); modelLoadingService.getModel(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } @@ -234,7 +234,7 @@ public boolean matches(final Object o) { // Load model 1, should invalidate 3 for(int i = 0; i < 10; i++) { - PlainActionFuture future1 = new PlainActionFuture<>(); + PlainActionFuture future1 = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } @@ -248,7 +248,7 @@ public boolean matches(final Object o) { // Load model 2 for(int i = 0; i < 10; i++) { - PlainActionFuture future2 = new PlainActionFuture<>(); + PlainActionFuture future2 = new PlainActionFuture<>(); modelLoadingService.getModel(model2, future2); assertThat(future2.get(), is(not(nullValue()))); } @@ -259,7 +259,7 @@ public boolean matches(final Object o) { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -286,7 +286,7 @@ public void testWhenCacheEnabledButNotIngestNode() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); for(int i = 0; i < 10; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future); assertThat(future.get(), is(not(nullValue()))); } @@ -309,7 +309,7 @@ public void testGetCachedMissingModel() throws Exception { "test-node"); modelLoadingService.clusterChanged(ingestChangedEvent(model)); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { @@ -336,7 +336,7 @@ public void testGetMissingModel() { Settings.EMPTY, "test-node"); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { future.get(); @@ -360,7 +360,7 @@ public void testGetModelEagerly() throws Exception { "test-node"); for(int i = 0; i < 3; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -458,7 +458,7 @@ private void onFailure(Exception e) { fail(e.getMessage()); } - ActionListener actionListener() { + ActionListener actionListener() { return ActionListener.wrap(this::onModelLoaded, this::onFailure); } } From 31178b123c888a96a1cde195ac69f0295d2da769 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 2 Jun 2020 08:26:41 +0100 Subject: [PATCH 12/16] fix tests --- .../inference/search/InferenceRescorer.java | 19 +++++------- .../search/InferenceRescorerBuilder.java | 30 +++++++++++++++---- .../search/InferenceRescorerBuilderTests.java | 14 +++++++-- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java index a519f706527d7..5e95bc394fd23 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorer.java @@ -24,9 +24,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; -import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import java.io.IOException; import java.util.Arrays; @@ -41,23 +40,21 @@ public class InferenceRescorer implements Rescorer { private static final Logger logger = LogManager.getLogger(InferenceRescorer.class); - private final Model model; - private final InferenceConfig inferenceConfig; + private final LocalModel model; + private final InferenceConfigUpdate update; private final Map fieldMap; private final InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings; InferenceRescorer( - Model model, - InferenceConfig inferenceConfig, + LocalModel model, + InferenceConfigUpdate update, Map fieldMap, InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings ) { this.model = model; - this.inferenceConfig = inferenceConfig; + this.update = update; this.fieldMap = fieldMap; this.scoreModeSettings = scoreModeSettings; - - assert inferenceConfig instanceof RegressionConfig; } @Override @@ -111,7 +108,7 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r } } - InferenceResults infer = model.infer(fields, inferenceConfig); + InferenceResults infer = model.infer(fields, update); if (infer instanceof WarningInferenceResults) { String message = ((WarningInferenceResults) infer).getWarning(); logger.warn("inference error: " + message); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java index aaadff7028247..694331c785dc0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java @@ -21,8 +21,14 @@ import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import java.io.IOException; @@ -59,7 +65,8 @@ public class InferenceRescorerBuilder extends RescorerBuilder p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG); + PARSER.declareNamedObject(optionalConstructorArg(), + (p, c, n) -> p.namedObject(StrictlyParsedInferenceConfig.class, n, c), INFERENCE_CONFIG); PARSER.declareField(optionalConstructorArg(), (p, c) -> p.mapStrings(), FIELD_MAPPINGS, ObjectParser.ValueType.OBJECT); PARSER.declareFloat(InferenceRescorerBuilder::setQueryWeight, QUERY_WEIGHT); PARSER.declareFloat(InferenceRescorerBuilder::setModelWeight, MODEL_WEIGHT); @@ -75,7 +82,7 @@ public static InferenceRescorerBuilder fromXContent(XContentParser parser, SetOn private final InferenceConfig inferenceConfig; private final Map fieldMap; - private Model model; + private LocalModel model; private float queryWeight = DEFAULT_QUERY_WEIGHT; private float modelWeight = DEFAULT_MODEL_WEIGHT; @@ -95,7 +102,7 @@ private InferenceRescorerBuilder(String modelId, SetOnce modelLoadingService, @Nullable InferenceConfig config, @Nullable Map fieldMap, - Supplier modelSupplier + Supplier modelSupplier ) { this(modelId, modelLoadingService, config, fieldMap); this.model = modelSupplier.get(); @@ -173,7 +180,7 @@ public RescorerBuilder rewrite(QueryRewriteContext ctx if (model != null) { return this; } else { - SetOnce modelHolder = new SetOnce<>(); + SetOnce modelHolder = new SetOnce<>(); ctx.registerAsyncAction(((client, actionListener) -> modelLoadingService.get().getModel(modelId, ActionListener.wrap( @@ -197,7 +204,18 @@ private InferenceRescorerBuilder copyScoringSettings(InferenceRescorerBuilder ta protected RescoreContext innerBuildContext(int windowSize, QueryShardContext context) { assert model != null; - return new RescoreContext(windowSize, new InferenceRescorer(model, inferenceConfig, fieldMap, scoreModeSettings())); + InferenceConfigUpdate update; + if (inferenceConfig == null) { + update = new EmptyConfigUpdate(); + } else if (inferenceConfig instanceof RegressionConfig) { + update = RegressionConfigUpdate.fromConfig((RegressionConfig)inferenceConfig); + } else { + // TODO better message + throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", + inferenceConfig.getName(), RegressionConfig.NAME.getPreferredName()); + } + + return new RescoreContext(windowSize, new InferenceRescorer(model, update, fieldMap, scoreModeSettings())); } static class ScoreModeSettings { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java index 685779bccb780..fd5ce7de75082 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilderTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.inference.search; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -16,20 +17,23 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import java.util.HashMap; import java.util.Map; +import static org.mockito.Mockito.mock; + public class InferenceRescorerBuilderTests extends AbstractSerializingTestCase { @Override protected InferenceRescorerBuilder doParseInstance(XContentParser parser) { - return InferenceRescorerBuilder.fromXContent(parser); + return InferenceRescorerBuilder.fromXContent(parser, new SetOnce<>(mock(ModelLoadingService.class))); } @Override protected Writeable.Reader instanceReader() { - return InferenceRescorerBuilder::new; + return in -> new InferenceRescorerBuilder(in, new SetOnce<>(mock(ModelLoadingService.class))); } @Override @@ -44,7 +48,11 @@ protected InferenceRescorerBuilder createTestInstance() { } } - InferenceRescorerBuilder builder = new InferenceRescorerBuilder(randomAlphaOfLength(8), config, randomMap()); + InferenceRescorerBuilder builder = new InferenceRescorerBuilder( + randomAlphaOfLength(8), + new SetOnce<>(mock(ModelLoadingService.class)), + config, randomMap()); + if (randomBoolean()) { builder.setQueryWeight((float) randomDoubleBetween(0.0, 1.0, true)); builder.setModelWeight((float) randomDoubleBetween(0.0, 2.0, true)); From ec9582399d1aa30c541fd7b76fe25a328a6fdd33 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 2 Jun 2020 09:35:35 +0100 Subject: [PATCH 13/16] Fix yml test --- .../search/InferenceRescorerBuilder.java | 35 +++++++++++++++---- .../rest-api-spec/test/ml/rescore.yml | 1 + 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java index 694331c785dc0..5bd3b62c96d11 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferenceRescorerBuilder.java @@ -83,6 +83,7 @@ public static InferenceRescorerBuilder fromXContent(XContentParser parser, SetOn private final Map fieldMap; private LocalModel model; + private Supplier modelSupplier; private float queryWeight = DEFAULT_QUERY_WEIGHT; private float modelWeight = DEFAULT_MODEL_WEIGHT; @@ -105,7 +106,17 @@ private InferenceRescorerBuilder(String modelId, Supplier modelSupplier ) { this(modelId, modelLoadingService, config, fieldMap); - this.model = modelSupplier.get(); + this.modelSupplier = modelSupplier; + } + + private InferenceRescorerBuilder(String modelId, + SetOnce modelLoadingService, + @Nullable InferenceConfig config, + @Nullable Map fieldMap, + LocalModel model + ) { + this(modelId, modelLoadingService, config, fieldMap); + this.model = model; } public InferenceRescorerBuilder(StreamInput in, SetOnce modelLoadingService) throws IOException { @@ -174,23 +185,33 @@ public String getWriteableName() { @Override public RescorerBuilder rewrite(QueryRewriteContext ctx) { - assert modelId != null; - if (model != null) { - return this; - } else { + if (modelSupplier != null) { + LocalModel m = modelSupplier.get(); + if (m == null) { + return this; + } else { + return copyScoringSettings( + new InferenceRescorerBuilder(modelId, modelLoadingService, inferenceConfig, fieldMap, m)); + } + } else if (model == null) { + SetOnce modelHolder = new SetOnce<>(); ctx.registerAsyncAction(((client, actionListener) -> modelLoadingService.get().getModel(modelId, ActionListener.wrap( - modelHolder::set, + m -> { + modelHolder.set(m); + actionListener.onResponse(null); + }, actionListener::onFailure)) )); return copyScoringSettings( new InferenceRescorerBuilder(modelId, modelLoadingService, inferenceConfig, fieldMap, modelHolder::get)); } + return this; } private InferenceRescorerBuilder copyScoringSettings(InferenceRescorerBuilder target) { @@ -208,7 +229,7 @@ protected RescoreContext innerBuildContext(int windowSize, QueryShardContext con if (inferenceConfig == null) { update = new EmptyConfigUpdate(); } else if (inferenceConfig instanceof RegressionConfig) { - update = RegressionConfigUpdate.fromConfig((RegressionConfig)inferenceConfig); + update = RegressionConfigUpdate.fromConfig((RegressionConfig) inferenceConfig); } else { // TODO better message throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml index 11543c6b76e3d..bfbcbefd97bf1 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml @@ -10,6 +10,7 @@ setup: { "description": "super complex model for tests", "input": {"field_names": ["decider"]}, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { From e7b1e950a737a4dfdd838ee5fa8b44e6a4f9b1ac Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 2 Jun 2020 09:36:14 +0100 Subject: [PATCH 14/16] Add EmptyConfigUpdate class --- .../trainedmodel/EmptyConfigUpdate.java | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java new file mode 100644 index 0000000000000..ff50aaa6ab558 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; + +public class EmptyConfigUpdate implements InferenceConfigUpdate { + public static final ParseField NAME = new ParseField("empty"); + + private static final ObjectParser PARSER = + new ObjectParser<>(NAME.getPreferredName(), EmptyConfigUpdate::new); + + public static EmptyConfigUpdate fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public EmptyConfigUpdate() { + } + + public EmptyConfigUpdate(StreamInput in) { + } + + @Override + public InferenceConfig apply(InferenceConfig originalConfig) { + return originalConfig; + } + + @Override + public InferenceConfig toConfig() { + return RegressionConfig.EMPTY_PARAMS; + } + + @Override + public boolean isSupported(InferenceConfig config) { + return true; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + // Equal if o is not null and the same class + return (o == null || getClass() != o.getClass()) == false; + } + + @Override + public int hashCode() { + return super.hashCode(); + } +} From ddc5632d2c64eb6b0b07d2877bdd6253a7bab8db Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 2 Jun 2020 13:28:14 +0100 Subject: [PATCH 15/16] Add bulk to list of rest specs --- .../plugin/ml/qa/ml-with-security/build.gradle | 2 +- .../resources/rest-api-spec/test/ml/rescore.yml | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 84ea1372cb1aa..2f092e87a6915 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -11,7 +11,7 @@ dependencies { // bring in machine learning rest test suite restResources { restApi { - includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest' + includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest', 'bulk' includeXpack 'ml', 'cat' } restTests { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml index bfbcbefd97bf1..86ac893d62b0d 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/rescore.yml @@ -72,18 +72,21 @@ setup: bulk: index: store refresh: true - body: | - { "index": {} } - { "goods": "television", "size": 32.0 } - { "index": {} } - { "goods": "VCR", "size": 0 } - { "index": {} } - { "goods": "widescreen television", "size": 40.0 } + body: + - '{ "index": {} }' + - '{ "goods": "television", "size": 32.0 }' + - '{ "index": {} }' + - '{ "goods": "VCR", "size": 0 }' + - '{ "index": {} }' + - '{ "goods": "widescreen television", "size": 40.0 }' --- "Test rescore": - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + Content-Type: application/json search: index: store body: | From af13b23841c4927543a6a6e590460c046c8112fe Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 2 Jun 2020 14:30:08 +0100 Subject: [PATCH 16/16] Blacklist rest test that doesn't throw security exception --- x-pack/plugin/ml/qa/ml-with-security/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 2f092e87a6915..9cc8e0c90e7ec 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -206,6 +206,7 @@ integTest.runner { 'ml/validate_detector/Test invalid detector', 'ml/delete_forecast/Test delete on _all forecasts not allow no forecasts', 'ml/delete_forecast/Test delete forecast on missing forecast', + 'ml/rescore/Test rescore', 'ml/set_upgrade_mode/Attempt to open job when upgrade_mode is enabled', 'ml/set_upgrade_mode/Setting upgrade_mode to enabled', 'ml/set_upgrade_mode/Setting upgrade mode to disabled from enabled',