Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketch out how to pass a service to a FetchSubPhase. #7

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -129,6 +131,8 @@
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields;
import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex;
Expand Down Expand Up @@ -208,6 +212,7 @@
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.search.InferencePhase;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
Expand Down Expand Up @@ -318,7 +323,7 @@

import static java.util.Collections.emptyList;

public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
Expand Down Expand Up @@ -414,6 +419,7 @@ public Set<DiscoveryNodeRole> getRoles() {
private final SetOnce<DataFrameAnalyticsManager> dataFrameAnalyticsManager = new SetOnce<>();
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();

public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
Expand Down Expand Up @@ -628,6 +634,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
clusterService,
xContentRegistry,
settings);
this.modelLoadingService.set(modelLoadingService);

// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
Expand Down Expand Up @@ -886,6 +893,15 @@ public Map<String, AnalysisProvider<TokenizerFactory>> getTokenizers() {
return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
}

@Override
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
    // TODO: add a way to specify the model ID and necessary config.
    // TODO: is the full inference config really needed?
    // Placeholder model id and config until they can be supplied by the search request.
    InferenceConfig placeholderConfig = new RegressionConfig("ignored");
    InferencePhase inferencePhase = new InferencePhase("model-id", placeholderConfig, this.modelLoadingService);
    return List.of(inferencePhase);
}

@Override
public UnaryOperator<Map<String, IndexTemplateMetaData>> getIndexTemplateMetaDataUpgrader() {
return UnaryOperator.identity();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.elasticsearch.xpack.ml.inference.search;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;

import java.io.IOException;
import java.util.Map;

/**
* A very rough sketch of a fetch sub phase that performs inference on each search hit,
* then augments the hit with the result.
*/
/**
 * A very rough sketch of a fetch sub phase that performs inference on each search hit,
 * then augments the hit with the result.
 *
 * <p>NOTE(review): {@code ModelLoadingService#getModel} and {@code Model#infer} take an
 * {@code ActionListener}, i.e. they are asynchronous APIs. The code below only works if
 * the listeners complete on the calling thread before {@code SetOnce#get()} is invoked
 * — TODO confirm, or block on the result explicitly (e.g. a PlainActionFuture).
 */
public class InferencePhase implements FetchSubPhase {
    // ID of the trained model to run inference with.
    private final String modelId;
    // Inference configuration forwarded to Model#infer for every hit.
    private final InferenceConfig config;
    // Set lazily by the plugin during createComponents, hence a SetOnce holder
    // rather than a direct reference.
    private final SetOnce<ModelLoadingService> modelLoadingService;

    public InferencePhase(String modelId,
                          InferenceConfig config,
                          SetOnce<ModelLoadingService> modelLoadingService) {
        this.modelId = modelId;
        this.config = config;
        this.modelLoadingService = modelLoadingService;
    }

    @Override
    public void hitsExecute(SearchContext searchContext, SearchHit[] hits) throws IOException {
        SetOnce<Model> modelHolder = new SetOnce<>();
        // Preserve the failure cause instead of throwing a bare RuntimeException,
        // so model-loading problems remain diagnosable.
        modelLoadingService.get().getModel(modelId, ActionListener.wrap(
            modelHolder::set,
            e -> { throw new RuntimeException("failed to load model [" + modelId + "]", e); }));

        // SetOnce#get() returns null if the listener has not completed yet; fail
        // loudly here rather than NPE-ing inside the loop below.
        Model model = modelHolder.get();
        if (model == null) {
            throw new IllegalStateException("model [" + modelId + "] was not loaded synchronously");
        }

        for (SearchHit hit : hits) {
            // TODO: get fields through context.lookup() (or from the search hit?)
            Map<String, Object> document = Map.of();

            SetOnce<InferenceResults> result = new SetOnce<>();
            model.infer(document, config, ActionListener.wrap(
                result::set,
                e -> { throw new RuntimeException("inference failed for model [" + modelId + "]", e); }));

            // TODO: add inference result to hit.
        }
    }
}