diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 67571d5b09a3f..6d70caf56b915 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -47,10 +47,12 @@
 import org.elasticsearch.plugins.IngestPlugin;
 import org.elasticsearch.plugins.PersistentTaskPlugin;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.search.fetch.FetchSubPhase;
 import org.elasticsearch.threadpool.ExecutorBuilder;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -129,6 +131,8 @@
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields;
 import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex;
@@ -208,6 +212,7 @@
 import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
+import org.elasticsearch.xpack.ml.inference.search.InferencePhase;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@@ -318,7 +323,7 @@
 import static java.util.Collections.emptyList;
 
-public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
+public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin, SearchPlugin {
     public static final String NAME = "ml";
     public static final String BASE_PATH = "/_ml/";
     public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
@@ -414,6 +419,7 @@ public Set<DiscoveryNodeRole> getRoles() {
     private final SetOnce<DataFrameAnalyticsManager> dataFrameAnalyticsManager = new SetOnce<>();
     private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
     private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
+    private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
 
     public MachineLearning(Settings settings, Path configPath) {
         this.settings = settings;
@@ -628,6 +634,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
             clusterService,
             xContentRegistry,
             settings);
+        this.modelLoadingService.set(modelLoadingService);
 
         // Data frame analytics components
         AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
@@ -886,6 +893,15 @@ public Map<String, AnalysisProvider<TokenizerFactory>> getTokenizers() {
         return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
     }
 
+    @Override
+    public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
+        // TODO: add a way to specify the model ID and necessary config.
+        // TODO: is the full inference config really needed?
+        String modelId = "model-id";
+        InferenceConfig config = new RegressionConfig("ignored");
+        return List.of(new InferencePhase(modelId, config, this.modelLoadingService));
+    }
+
     @Override
     public UnaryOperator<Map<String, IndexTemplateMetaData>> getIndexTemplateMetaDataUpgrader() {
         return UnaryOperator.identity();
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferencePhase.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferencePhase.java
new file mode 100644
index 0000000000000..33e69cadd086b
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/search/InferencePhase.java
@@ -0,0 +1,52 @@
+package org.elasticsearch.xpack.ml.inference.search;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.fetch.FetchSubPhase;
+import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * A very rough sketch of a fetch sub phase that performs inference on each search hit,
+ * then augments the hit with the result.
+ */
+public class InferencePhase implements FetchSubPhase {
+    private final String modelId;
+    private final InferenceConfig config;
+    private final SetOnce<ModelLoadingService> modelLoadingService;
+
+    public InferencePhase(String modelId,
+                          InferenceConfig config,
+                          SetOnce<ModelLoadingService> modelLoadingService) {
+        this.modelId = modelId;
+        this.config = config;
+        this.modelLoadingService = modelLoadingService;
+    }
+
+    @Override
+    public void hitsExecute(SearchContext searchContext, SearchHit[] hits) throws IOException {
+        SetOnce<Model> model = new SetOnce<>();
+        modelLoadingService.get().getModel(modelId, ActionListener.wrap(
+            model::set,
+            m -> { throw new RuntimeException(m); }));
+
+        for (SearchHit hit : hits) {
+            // TODO: get fields through context.lookup() (or from the search hit?)
+            Map<String, Object> document = Map.of();
+
+            SetOnce<InferenceResults> result = new SetOnce<>();
+            model.get().infer(document, config, ActionListener.wrap(
+                result::set,
+                m -> { throw new RuntimeException(m); }));
+
+            // TODO: add inference result to hit.
+        }
+    }
+}
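
Note on the sketch above (not part of the diff): hitsExecute resolves the model through an async listener but reads the SetOnce immediately, so model.get() can return null if the listener has not fired yet. Below is a minimal alternative sketch, not the author's implementation: it blocks on a PlainActionFuture and uses the hit _source as a stand-in for the feature fields. It assumes ModelLoadingService#getModel and Model#infer exactly as they are used in InferencePhase, the blocking call and the _source-as-document shortcut are simplifications, and attaching the result to the hit remains the open TODO.

    // Sketch only, under the assumptions stated above.
    // Requires: import org.elasticsearch.action.support.PlainActionFuture;
    @Override
    public void hitsExecute(SearchContext searchContext, SearchHit[] hits) throws IOException {
        // Block until the model is available instead of racing a SetOnce that may
        // not have been populated by the time model.get() is called.
        PlainActionFuture<Model> modelFuture = new PlainActionFuture<>();
        modelLoadingService.get().getModel(modelId, modelFuture);
        Model model = modelFuture.actionGet();

        for (SearchHit hit : hits) {
            // Hypothetical shortcut: feed the hit's _source to the model. A real
            // implementation would extract the configured fields via the search lookup,
            // as the TODO in the diff notes.
            Map<String, Object> document = hit.getSourceAsMap();

            PlainActionFuture<InferenceResults> resultFuture = new PlainActionFuture<>();
            model.infer(document, config, resultFuture);
            InferenceResults result = resultFuture.actionGet();

            // Attaching 'result' to the hit is still the open TODO from the sketch above.
        }
    }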