Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass in ModelLoadingService to the inference rescorer. #8

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Setting;
Expand All @@ -33,6 +34,7 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.index.analysis.TokenizerFactory;
Expand All @@ -47,6 +49,7 @@
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
Expand Down Expand Up @@ -211,6 +214,7 @@
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.inference.search.InferenceRescorerBuilder;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
Expand Down Expand Up @@ -318,7 +322,7 @@

import static java.util.Collections.emptyList;

public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
Expand All @@ -342,6 +346,22 @@ protected Setting<Boolean> roleSetting() {

};

@Override
public List<RescorerSpec<?>> getRescorers() {
    // Registers the "inference" rescorer so it can be used in the rescore
    // section of a search request; both factories supply the (lazily-set)
    // model loading service to the builder.
    RescorerSpec<InferenceRescorerBuilder> inferenceSpec = new RescorerSpec<>(
        InferenceRescorerBuilder.NAME,
        this::rescorerFromStream,
        this::rescorerFromXContent);
    return Collections.singletonList(inferenceSpec);
}

/**
 * Deserializes an {@link InferenceRescorerBuilder} from the wire,
 * injecting the model loading service holder.
 */
private InferenceRescorerBuilder rescorerFromStream(StreamInput in) throws IOException {
    return new InferenceRescorerBuilder(in, modelLoadingService);
}

/**
 * Parses an {@link InferenceRescorerBuilder} from XContent,
 * injecting the model loading service holder.
 */
private InferenceRescorerBuilder rescorerFromXContent(XContentParser parser) {
    return InferenceRescorerBuilder.fromXContent(parser, modelLoadingService);
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
if (this.enabled == false) {
Expand Down Expand Up @@ -414,6 +434,7 @@ public Set<DiscoveryNodeRole> getRoles() {
private final SetOnce<DataFrameAnalyticsManager> dataFrameAnalyticsManager = new SetOnce<>();
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();

public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
Expand Down Expand Up @@ -628,6 +649,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
clusterService,
xContentRegistry,
settings);
this.modelLoadingService.set(modelLoadingService);

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.util.HashSet;
Expand Down Expand Up @@ -44,6 +44,11 @@ public String getModelId() {
return modelId;
}

/**
 * @return the document field names this model expects as inference input
 */
@Override
public Set<String> getFieldNames() {
    return fieldNames;
}

@Override
public String getResultsType() {
switch (trainedModelDefinition.getTrainedModel().targetType()) {
Expand All @@ -61,15 +66,19 @@ public String getResultsType() {
/**
 * Asynchronous inference entry point.
 * <p>
 * Delegates to the synchronous {@code infer(fields, config)} (which already
 * handles the "all input fields missing" warning case) and reports the
 * outcome through the listener exactly once: either {@code onResponse}
 * with the results, or {@code onFailure} if inference (or the response
 * handler itself) throws.
 *
 * @param fields   document fields available for inference
 * @param config   the inference configuration to apply
 * @param listener receives the inference results or the failure
 */
@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
    try {
        listener.onResponse(infer(fields, config));
    } catch (Exception e) {
        listener.onFailure(e);
    }
}

/**
 * Synchronous inference: runs the trained model definition over the
 * supplied fields, or returns a warning result when none of the model's
 * expected input fields can be found in the document.
 */
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
    // Guard clause: if every expected field resolves to null, warn instead of failing.
    boolean allFieldsMissing = fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null);
    if (allFieldsMissing) {
        return new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId));
    }
    return trainedModelDefinition.infer(fields, config);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;

import java.util.Map;
import java.util.Set;

/**
 * A loaded, ready-to-use inference model.
 */
public interface Model {

    /**
     * @return the type of results this model produces
     *         (regression or classification)
     */
    String getResultsType();

    /**
     * Asynchronous inference over the given document fields; the results
     * or any failure are delivered via the listener.
     */
    void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);

    /**
     * Synchronous inference over the given document fields.
     */
    InferenceResults infer(Map<String, Object> fields, InferenceConfig config);

    /**
     * @return the identifier of the trained model backing this instance
     */
    String getModelId();

    /**
     * @return the field names the model expects as inference input
     */
    Set<String> getFieldNames();

}
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,10 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio
listener::onFailure
);

executeAsyncWithOrigin(client,
ML_ORIGIN,
MultiSearchResponse response = client.execute(
MultiSearchAction.INSTANCE,
multiSearchRequestBuilder.request(),
multiSearchResponseActionListener);
multiSearchRequestBuilder.request()).actionGet();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same problem we have with the FetchSubPhase PR: we have to make a blocking call somewhere.

multiSearchResponseActionListener.onResponse(response);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.inference.search;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.index.fielddata.FieldData;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.rescore.Rescorer;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A {@link Rescorer} that re-scores the top hits of a query by running an ML
 * inference model over numeric doc-values fields of each hit and combining
 * the model's prediction with the original query score.
 */
public class InferenceRescorer implements Rescorer {

    private static final Logger logger = LogManager.getLogger(InferenceRescorer.class);

    private final Model model;
    private final InferenceConfig inferenceConfig;
    private final Map<String, String> fieldMap;
    private final InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings;

    /**
     * @param model             the loaded inference model used to score each hit
     * @param inferenceConfig   inference settings; only regression is supported (see assert)
     * @param fieldMap          maps doc field name -> field name expected by the model
     * @param scoreModeSettings how to weight and combine the query and model scores
     */
    InferenceRescorer(Model model,
                      InferenceConfig inferenceConfig,
                      Map<String, String> fieldMap,
                      InferenceRescorerBuilder.ScoreModeSettings scoreModeSettings) {
        this.model = model;
        this.inferenceConfig = inferenceConfig;
        this.fieldMap = fieldMap;
        this.scoreModeSettings = scoreModeSettings;

        assert inferenceConfig instanceof RegressionConfig;
    }

    @Override
    public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException {

        // Copy ScoreDoc[] and sort by ascending docID so the segment leaves
        // can be walked forward exactly once.
        ScoreDoc[] sortedHits = topDocs.scoreDocs.clone();
        Arrays.sort(sortedHits, Comparator.comparingInt(sd -> sd.doc));

        // Work out which doc fields to read: start from the model's expected
        // field names, then substitute any doc-side names from the field map.
        Set<String> fieldsToRead = new HashSet<>(model.getFieldNames());
        // field map is fieldname in doc -> fieldname expected by model
        for (Map.Entry<String, String> entry : fieldMap.entrySet()) {
            if (fieldsToRead.contains(entry.getValue())) {
                // replace the model fieldname with the doc fieldname
                fieldsToRead.remove(entry.getValue());
                fieldsToRead.add(entry.getKey());
            }
        }

        List<LeafReaderContext> leaves = searcher.getIndexReader().getContext().leaves();
        Map<String, Object> fields = new HashMap<>();

        int currentReader = 0;
        int endDoc = 0;
        LeafReaderContext readerContext = null;

        for (int hitIndex = 0; hitIndex < sortedHits.length; hitIndex++) {
            ScoreDoc hit = sortedHits[hitIndex];
            int docId = hit.doc;

            // Advance to the leaf containing this doc (hits are docID-ordered).
            while (docId >= endDoc) {
                readerContext = leaves.get(currentReader);
                currentReader++;
                endDoc = readerContext.docBase + readerContext.reader().maxDoc();
            }

            // BUGFIX: the map is reused across hits; clear it so a field that is
            // missing for this doc is not silently filled with the previous hit's
            // value (the old code removed by doc field name while values were
            // stored under the mapped model field name, so stale entries leaked).
            fields.clear();

            // Segment-local doc id: per-leaf doc values iterators are addressed
            // relative to the leaf, not by the global doc id.
            int leafDocId = docId - readerContext.docBase;

            for (String field : fieldsToRead) {
                SortedNumericDocValues docValuesIter = DocValues.getSortedNumeric(readerContext.reader(), field);
                SortedNumericDoubleValues doubles = FieldData.sortableLongBitsToDoubles(docValuesIter);
                if (doubles.advanceExact(leafDocId)) {
                    // NOTE(review): only the first value of a multi-valued field is used.
                    double val = doubles.nextValue();
                    fields.put(fieldMap.getOrDefault(field, field), val);
                } else if (docValuesIter.docID() == DocIdSetIterator.NO_MORE_DOCS) {
                    logger.warn("No more docs for field {}, doc {}", field, hit.doc);
                } else {
                    logger.warn("no value for field {}, doc {}", field, hit.doc);
                }
            }

            InferenceResults infer = model.infer(fields, inferenceConfig);
            if (infer instanceof WarningInferenceResults) {
                String message = ((WarningInferenceResults) infer).getWarning();
                logger.warn("inference error: " + message);
                throw new ElasticsearchException(message);
            } else {
                SingleValueInferenceResults regressionResult = (SingleValueInferenceResults) infer;

                float combinedScore = scoreModeSettings.scoreMode.combine(
                    hit.score * scoreModeSettings.queryWeight,
                    regressionResult.value().floatValue() * scoreModeSettings.modelWeight
                );

                sortedHits[hitIndex] = new ScoreDoc(hit.doc, combinedScore);
            }
        }

        // Callers expect rescored hits in score order (descending), with the
        // doc id as a stable tie-break, not the docID order used internally.
        Arrays.sort(sortedHits, (a, b) -> {
            int cmp = Float.compare(b.score, a.score);
            return cmp != 0 ? cmp : Integer.compare(a.doc, b.doc);
        });

        return new TopDocs(topDocs.totalHits, sortedHits);
    }

    @Override
    public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, Explanation sourceExplanation) {
        // TODO: placeholder explanation; should describe the model score and
        // how it was combined with the query score.
        return Explanation.match(1.0, "because");
    }
}
Loading