diff --git a/docs/changelog/102713.yaml b/docs/changelog/102713.yaml new file mode 100644 index 0000000000000..278d7d4ffb129 --- /dev/null +++ b/docs/changelog/102713.yaml @@ -0,0 +1,5 @@ +pr: 102713 +summary: "ESQL: Add `profile` option" +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/102731.yaml b/docs/changelog/102731.yaml new file mode 100644 index 0000000000000..a12e04bfab078 --- /dev/null +++ b/docs/changelog/102731.yaml @@ -0,0 +1,5 @@ +pr: 102731 +summary: Add internal inference action for ml models an services +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/102806.yaml b/docs/changelog/102806.yaml new file mode 100644 index 0000000000000..faa971ec1d879 --- /dev/null +++ b/docs/changelog/102806.yaml @@ -0,0 +1,5 @@ +pr: 102806 +summary: Support for GET all models and by task type in the `_inference` API +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/102832.yaml b/docs/changelog/102832.yaml new file mode 100644 index 0000000000000..7daf22263b2e9 --- /dev/null +++ b/docs/changelog/102832.yaml @@ -0,0 +1,5 @@ +pr: 102832 +summary: Disable concurrency for sampler and diversified sampler +area: Aggregations +type: enhancement +issues: [] diff --git a/docs/changelog/102844.yaml b/docs/changelog/102844.yaml new file mode 100644 index 0000000000000..d05547c3aa9da --- /dev/null +++ b/docs/changelog/102844.yaml @@ -0,0 +1,5 @@ +pr: 102844 +summary: Skip global ordinals loading if query does not match after rewrite +area: Aggregations +type: bug +issues: [] diff --git a/docs/changelog/102848.yaml b/docs/changelog/102848.yaml new file mode 100644 index 0000000000000..971d91a878579 --- /dev/null +++ b/docs/changelog/102848.yaml @@ -0,0 +1,5 @@ +pr: 102848 +summary: Decref `SharedBytes.IO` after read is done not before +area: Snapshot/Restore +type: bug +issues: [] diff --git a/docs/reference/mapping/types/geo-shape.asciidoc b/docs/reference/mapping/types/geo-shape.asciidoc 
index 37ef340733932..628f764c04fe9 100644 --- a/docs/reference/mapping/types/geo-shape.asciidoc +++ b/docs/reference/mapping/types/geo-shape.asciidoc @@ -30,6 +30,15 @@ The `geo_shape` mapping maps GeoJSON or WKT geometry objects to the `geo_shape` type. To enable it, users must explicitly map fields to the `geo_shape` type. +[NOTE] +============================================= +In https://datatracker.ietf.org/doc/html/rfc7946[GeoJSON] +and https://www.ogc.org/standard/sfa/[WKT], and therefore Elasticsearch, +the correct *coordinate order is longitude, latitude (X, Y)* within coordinate +arrays. This differs from many Geospatial APIs (e.g., Google Maps) that generally +use the colloquial latitude, longitude (Y, X). +============================================= + [cols="<,<,<",options="header",] |======================================================================= |Option |Description| Default @@ -142,11 +151,6 @@ specifying only the top left and bottom right points. ============================================= For all types, both the inner `type` and `coordinates` fields are required. - -In GeoJSON and WKT, and therefore Elasticsearch, the correct *coordinate -order is longitude, latitude (X, Y)* within coordinate arrays. This -differs from many Geospatial APIs (e.g., Google Maps) that generally -use the colloquial latitude, longitude (Y, X). 
============================================= [[geo-point-type]] diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index ed7ae1b5b5638..8c5022cea289d 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -1,5 +1,5 @@ - + false false @@ -1401,19 +1401,19 @@ - - - + + + - - - + + + - - - + + + diff --git a/modules/apm/build.gradle b/modules/apm/build.gradle index 13f1ac4a4cd3e..4c822e44da6f6 100644 --- a/modules/apm/build.gradle +++ b/modules/apm/build.gradle @@ -12,12 +12,13 @@ esplugin { classname 'org.elasticsearch.telemetry.apm.APM' } -def otelVersion = '1.17.0' +def otelVersion = '1.31.0' +def otelSemconvVersion = '1.21.0-alpha' dependencies { implementation "io.opentelemetry:opentelemetry-api:${otelVersion}" implementation "io.opentelemetry:opentelemetry-context:${otelVersion}" - implementation "io.opentelemetry:opentelemetry-semconv:${otelVersion}-alpha" + implementation "io.opentelemetry:opentelemetry-semconv:${otelSemconvVersion}" runtimeOnly "co.elastic.apm:elastic-apm-agent:1.44.0" } diff --git a/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/10_basic.yml b/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/10_basic.yml index 6496930764ab8..b1e0cf8ed7d90 100644 --- a/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/10_basic.yml +++ b/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/10_basic.yml @@ -210,8 +210,10 @@ setup: --- "Create data stream with failure store": - skip: - version: " - 8.10.99" - reason: "data stream failure stores only creatable in 8.11+" + version: all + reason: "AwaitsFix https://github.com/elastic/elasticsearch/issues/102873" +# version: " - 8.10.99" +# reason: "data stream failure stores only creatable in 8.11+" - do: allowed_warnings: diff --git 
a/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/30_auto_create_data_stream.yml b/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/30_auto_create_data_stream.yml index 303a584555f8f..a7d8476ee2dcf 100644 --- a/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/30_auto_create_data_stream.yml +++ b/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/30_auto_create_data_stream.yml @@ -50,8 +50,10 @@ --- "Put index template with failure store": - skip: - version: " - 8.10.99" - reason: "data stream failure stores only creatable in 8.11+" + version: all + reason: "AwaitsFix https://github.com/elastic/elasticsearch/issues/102873" +# version: " - 8.10.99" +# reason: "data stream failure stores only creatable in 8.11+" features: allowed_warnings - do: diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java index 73669ccacdbc6..49ad8302605cf 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomMustacheFactory.java @@ -47,6 +47,7 @@ public final class CustomMustacheFactory extends DefaultMustacheFactory { static final String X_WWW_FORM_URLENCODED_MEDIA_TYPE = "application/x-www-form-urlencoded"; private static final String DEFAULT_MEDIA_TYPE = JSON_MEDIA_TYPE; + private static final boolean DEFAULT_DETECT_MISSING_PARAMS = false; private static final Map> ENCODERS = Map.of( V7_JSON_MEDIA_TYPE_WITH_CHARSET, @@ -63,14 +64,30 @@ public final class CustomMustacheFactory extends DefaultMustacheFactory { private final Encoder encoder; + /** + * Initializes a CustomMustacheFactory object with a specified mediaType. 
+ * + * @deprecated Use {@link #builder()} instead to retrieve a {@link Builder} object that can be used to create a factory. + */ + @Deprecated public CustomMustacheFactory(String mediaType) { - super(); - setObjectHandler(new CustomReflectionObjectHandler()); - this.encoder = createEncoder(mediaType); + this(mediaType, DEFAULT_DETECT_MISSING_PARAMS); } + /** + * Default constructor for the factory. + * + * @deprecated Use {@link #builder()} instead to retrieve a {@link Builder} object that can be used to create a factory. + */ + @Deprecated public CustomMustacheFactory() { - this(DEFAULT_MEDIA_TYPE); + this(DEFAULT_MEDIA_TYPE, DEFAULT_DETECT_MISSING_PARAMS); + } + + private CustomMustacheFactory(String mediaType, boolean detectMissingParams) { + super(); + setObjectHandler(new CustomReflectionObjectHandler(detectMissingParams)); + this.encoder = createEncoder(mediaType); } @Override @@ -95,6 +112,10 @@ public MustacheVisitor createMustacheVisitor() { return new CustomMustacheVisitor(this); } + public static Builder builder() { + return new Builder(); + } + class CustomMustacheVisitor extends DefaultMustacheVisitor { CustomMustacheVisitor(DefaultMustacheFactory df) { @@ -360,4 +381,34 @@ public void encode(String s, Writer writer) throws IOException { writer.write(URLEncoder.encode(s, StandardCharsets.UTF_8)); } } + + /** + * Build a new {@link CustomMustacheFactory} object. + */ + public static class Builder { + private String mediaType = DEFAULT_MEDIA_TYPE; + private boolean detectMissingParams = DEFAULT_DETECT_MISSING_PARAMS; + + private Builder() {} + + public Builder mediaType(String mediaType) { + this.mediaType = mediaType; + return this; + } + + /** + * Sets the behavior for handling missing parameters during template execution. + * + * @param detectMissingParams If true, an exception is thrown when executing the template with missing parameters. + * If false, the template gracefully handles missing parameters without throwing an exception. 
+ */ + public Builder detectMissingParams(boolean detectMissingParams) { + this.detectMissingParams = detectMissingParams; + return this; + } + + public CustomMustacheFactory build() { + return new CustomMustacheFactory(mediaType, detectMissingParams); + } + } } diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java index c1e87fdc0970e..491ec6c851342 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java @@ -8,7 +8,15 @@ package org.elasticsearch.script.mustache; +import com.github.mustachejava.Binding; +import com.github.mustachejava.Code; +import com.github.mustachejava.ObjectHandler; +import com.github.mustachejava.TemplateContext; +import com.github.mustachejava.codes.ValueCode; +import com.github.mustachejava.reflect.GuardedBinding; +import com.github.mustachejava.reflect.MissingWrapper; import com.github.mustachejava.reflect.ReflectionObjectHandler; +import com.github.mustachejava.util.Wrapper; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.util.Maps; @@ -19,10 +27,16 @@ import java.util.AbstractMap; import java.util.Collection; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Set; final class CustomReflectionObjectHandler extends ReflectionObjectHandler { + private final boolean detectMissingParams; + + CustomReflectionObjectHandler(boolean detectMissingParams) { + this.detectMissingParams = detectMissingParams; + } @Override public Object coerce(Object object) { @@ -41,6 +55,11 @@ public Object coerce(Object object) { } } + @Override + public Binding createBinding(String name, TemplateContext tc, Code code) { + return detectMissingParams ? 
new DetectMissingParamsGuardedBinding(this, name, tc, code) : super.createBinding(name, tc, code); + } + @Override @SuppressWarnings("rawtypes") protected AccessibleObject findMember(Class sClass, String name) { @@ -59,6 +78,23 @@ protected AccessibleObject findMember(Class sClass, String name) { return null; } + static class DetectMissingParamsGuardedBinding extends GuardedBinding { + private final Code code; + + DetectMissingParamsGuardedBinding(ObjectHandler oh, String name, TemplateContext tc, Code code) { + super(oh, name, tc, code); + this.code = code; + } + + protected synchronized Wrapper getWrapper(String name, List scopes) { + Wrapper wrapper = super.getWrapper(name, scopes); + if (wrapper instanceof MissingWrapper && code instanceof ValueCode) { + throw new MustacheInvalidParameterException("Parameter [" + name + "] is missing"); + } + return wrapper; + } + } + static final class ArrayMap extends AbstractMap implements Iterable { private final Object array; diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheInvalidParameterException.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheInvalidParameterException.java new file mode 100644 index 0000000000000..9aaf8cdae89ad --- /dev/null +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheInvalidParameterException.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.script.mustache; + +import com.github.mustachejava.MustacheException; + +public class MustacheInvalidParameterException extends MustacheException { + MustacheInvalidParameterException(String message) { + super(message, null, null); + } +} diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java index c6f60c48c4ab4..61102de0ab5a4 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java @@ -38,6 +38,10 @@ * {@link Mustache} object can then be re-used for subsequent executions. */ public final class MustacheScriptEngine implements ScriptEngine { + /** + * Compiler option to enable detection of missing parameters. + */ + public static final String DETECT_MISSING_PARAMS_OPTION = "detect_missing_params"; private static final Logger logger = LogManager.getLogger(MustacheScriptEngine.class); public static final String NAME = "mustache"; @@ -72,10 +76,20 @@ public Set> getSupportedContexts() { } private static CustomMustacheFactory createMustacheFactory(Map options) { - if (options == null || options.isEmpty() || options.containsKey(Script.CONTENT_TYPE_OPTION) == false) { - return new CustomMustacheFactory(); + CustomMustacheFactory.Builder builder = CustomMustacheFactory.builder(); + if (options == null || options.isEmpty()) { + return builder.build(); + } + + if (options.containsKey(Script.CONTENT_TYPE_OPTION)) { + builder.mediaType(options.get(Script.CONTENT_TYPE_OPTION)); + } + + if (options.containsKey(DETECT_MISSING_PARAMS_OPTION)) { + builder.detectMissingParams(Boolean.valueOf(options.get(DETECT_MISSING_PARAMS_OPTION))); } - return new CustomMustacheFactory(options.get(Script.CONTENT_TYPE_OPTION)); + + return builder.build(); } 
@Override @@ -107,10 +121,17 @@ public String execute() { try { template.execute(writer, params); } catch (Exception e) { - logger.error(() -> format("Error running %s", template), e); + if (shouldLogException(e)) { + logger.error(() -> format("Error running %s", template), e); + } throw new GeneralScriptException("Error running " + template, e); } return writer.toString(); } + + public boolean shouldLogException(Throwable e) { + return e.getCause() != null && e.getCause() instanceof MustacheInvalidParameterException == false; + } } + } diff --git a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java index 0d3e881e54a56..4896584d7aadf 100644 --- a/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java +++ b/modules/lang-mustache/src/test/java/org/elasticsearch/script/mustache/MustacheScriptEngineTests.java @@ -9,6 +9,7 @@ import com.github.mustachejava.MustacheFactory; +import org.elasticsearch.script.GeneralScriptException; import org.elasticsearch.script.Script; import org.elasticsearch.script.TemplateScript; import org.elasticsearch.test.ESTestCase; @@ -18,10 +19,13 @@ import java.io.IOException; import java.io.StringWriter; +import java.util.Collections; import java.util.List; import java.util.Map; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.startsWith; /** * Mustache based templating test @@ -33,7 +37,7 @@ public class MustacheScriptEngineTests extends ESTestCase { @Before public void setup() { qe = new MustacheScriptEngine(); - factory = new CustomMustacheFactory(); + factory = CustomMustacheFactory.builder().build(); } public void testSimpleParameterReplace() { @@ -196,6 +200,55 @@ public void testSimple() throws IOException { assertThat(TemplateScript.execute(), 
equalTo("{\"match_all\":{}}")); } + public void testDetectMissingParam() { + Map scriptOptions = Map.ofEntries(Map.entry(MustacheScriptEngine.DETECT_MISSING_PARAMS_OPTION, "true")); + + // fails when a param is missing and the DETECT_MISSING_PARAMS_OPTION option is set to true. + { + String source = "{\"match\": { \"field\": \"{{query_string}}\" }"; + TemplateScript.Factory compiled = qe.compile(null, source, TemplateScript.CONTEXT, scriptOptions); + Map params = Collections.emptyMap(); + GeneralScriptException e = expectThrows(GeneralScriptException.class, () -> compiled.newInstance(params).execute()); + assertThat(e.getRootCause(), instanceOf(MustacheInvalidParameterException.class)); + assertThat(e.getRootCause().getMessage(), startsWith("Parameter [query_string] is missing")); + } + + // fails when params is null and the DETECT_MISSING_PARAMS_OPTION option is set to true. + { + String source = "{\"match\": { \"field\": \"{{query_string}}\" }"; + TemplateScript.Factory compiled = qe.compile(null, source, TemplateScript.CONTEXT, scriptOptions); + GeneralScriptException e = expectThrows(GeneralScriptException.class, () -> compiled.newInstance(null).execute()); + assertThat(e.getRootCause(), instanceOf(MustacheInvalidParameterException.class)); + assertThat(e.getRootCause().getMessage(), startsWith("Parameter [query_string] is missing")); + } + + // works as expected when params are specified and the DETECT_MISSING_PARAMS_OPTION option is set to true + { + String source = "{\"match\": { \"field\": \"{{query_string}}\" }"; + TemplateScript.Factory compiled = qe.compile(null, source, TemplateScript.CONTEXT, scriptOptions); + Map params = Map.ofEntries(Map.entry("query_string", "foo")); + assertThat(compiled.newInstance(params).execute(), equalTo("{\"match\": { \"field\": \"foo\" }")); + } + + // do not throw when using a missing param in the conditional when DETECT_MISSING_PARAMS_OPTION option is set to true + { + String source = "{\"match\": { \"field\": 
\"{{#query_string}}{{.}}{{/query_string}}\" }"; + TemplateScript.Factory compiled = qe.compile(null, source, TemplateScript.CONTEXT, scriptOptions); + Map params = Map.of(); + assertThat(compiled.newInstance(params).execute(), equalTo("{\"match\": { \"field\": \"\" }")); + } + } + + public void testMissingParam() { + Map scriptOptions = Collections.emptyMap(); + String source = "{\"match\": { \"field\": \"{{query_string}}\" }"; + TemplateScript.Factory compiled = qe.compile(null, source, TemplateScript.CONTEXT, scriptOptions); + + // When the DETECT_MISSING_PARAMS_OPTION is not specified, missing variable is replaced with an empty string. + assertThat(compiled.newInstance(Collections.emptyMap()).execute(), equalTo("{\"match\": { \"field\": \"\" }")); + assertThat(compiled.newInstance(null).execute(), equalTo("{\"match\": { \"field\": \"\" }")); + } + public void testParseTemplateAsSingleStringWithConditionalClause() throws IOException { String templateString = """ { diff --git a/modules/repository-url/build.gradle b/modules/repository-url/build.gradle index 7b671802f3a2a..2850aee68a2fb 100644 --- a/modules/repository-url/build.gradle +++ b/modules/repository-url/build.gradle @@ -8,12 +8,9 @@ import org.elasticsearch.gradle.PropertyNormalization -apply plugin: 'elasticsearch.legacy-yaml-rest-test' -apply plugin: 'elasticsearch.legacy-yaml-rest-compat-test' +apply plugin: 'elasticsearch.internal-yaml-rest-test' +apply plugin: 'elasticsearch.yaml-rest-compat-test' apply plugin: 'elasticsearch.internal-cluster-test' -apply plugin: 'elasticsearch.test.fixtures' - -final Project fixture = project(':test:fixtures:url-fixture') esplugin { description 'Module for URL repository' @@ -32,6 +29,8 @@ dependencies { api "commons-logging:commons-logging:${versions.commonslogging}" api "commons-codec:commons-codec:${versions.commonscodec}" api "org.apache.logging.log4j:log4j-1.2-api:${versions.log4j}" + yamlRestTestImplementation project(':test:fixtures:url-fixture') + 
internalClusterTestImplementation project(':test:fixtures:url-fixture') } tasks.named("thirdPartyAudit").configure { @@ -45,15 +44,7 @@ tasks.named("thirdPartyAudit").configure { ) } -testFixtures.useFixture(fixture.path, 'url-fixture') - -def fixtureAddress = { fixtureName -> - int ephemeralPort = fixture.postProcessFixture.ext."test.fixtures.${fixtureName}.tcp.80" - assert ephemeralPort > 0 - 'http://127.0.0.1:' + ephemeralPort -} - -File repositoryDir = fixture.fsRepositoryDir as File +//File repositoryDir = fixture.fsRepositoryDir as File testClusters.configureEach { // repositoryDir is used by a FS repository to create snapshots diff --git a/modules/repository-url/src/yamlRestTest/java/org/elasticsearch/repositories/url/RepositoryURLClientYamlTestSuiteIT.java b/modules/repository-url/src/yamlRestTest/java/org/elasticsearch/repositories/url/RepositoryURLClientYamlTestSuiteIT.java index 0958276656a81..a5b1a48f94ac9 100644 --- a/modules/repository-url/src/yamlRestTest/java/org/elasticsearch/repositories/url/RepositoryURLClientYamlTestSuiteIT.java +++ b/modules/repository-url/src/yamlRestTest/java/org/elasticsearch/repositories/url/RepositoryURLClientYamlTestSuiteIT.java @@ -8,6 +8,8 @@ package org.elasticsearch.repositories.url; +import fixture.url.URLFixture; + import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; @@ -22,11 +24,15 @@ import org.elasticsearch.core.PathUtils; import org.elasticsearch.repositories.fs.FsRepository; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.junit.Before; +import org.junit.ClassRule; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; import 
java.io.IOException; import java.net.InetAddress; @@ -42,6 +48,22 @@ public class RepositoryURLClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { + public static final URLFixture urlFixture = new URLFixture(); + + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .module("repository-url") + .setting("path.repo", urlFixture::getRepositoryDir) + .setting("repositories.url.allowed_urls", () -> "http://snapshot.test*, " + urlFixture.getAddress()) + .build(); + + @ClassRule + public static TestRule ruleChain = RuleChain.outerRule(urlFixture).around(cluster); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + public RepositoryURLClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { super(testCandidate); } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/connector.update_filtering.json b/rest-api-spec/src/main/resources/rest-api-spec/api/connector.update_filtering.json new file mode 100644 index 0000000000000..6923dc88006e3 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/connector.update_filtering.json @@ -0,0 +1,39 @@ +{ + "connector.update_filtering": { + "documentation": { + "url": "https://www.elastic.co/guide/en/enterprise-search/current/connectors.html", + "description": "Updates the filtering field in the connector document." + }, + "stability": "experimental", + "visibility": "feature_flag", + "feature_flag": "es.connector_api_feature_flag_enabled", + "headers": { + "accept": [ + "application/json" + ], + "content_type": [ + "application/json" + ] + }, + "url": { + "paths": [ + { + "path": "/_connector/{connector_id}/_filtering", + "methods": [ + "PUT" + ], + "parts": { + "connector_id": { + "type": "string", + "description": "The unique identifier of the connector to be updated." 
+ } + } + } + ] + }, + "body": { + "description": "A list of connector filtering configurations.", + "required": true + } + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/connector.update_pipeline.json b/rest-api-spec/src/main/resources/rest-api-spec/api/connector.update_pipeline.json new file mode 100644 index 0000000000000..2bd1acf7d28a6 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/connector.update_pipeline.json @@ -0,0 +1,39 @@ +{ + "connector.update_pipeline": { + "documentation": { + "url": "https://www.elastic.co/guide/en/enterprise-search/current/connectors.html", + "description": "Updates the pipeline field in the connector document." + }, + "stability": "experimental", + "visibility": "feature_flag", + "feature_flag": "es.connector_api_feature_flag_enabled", + "headers": { + "accept": [ + "application/json" + ], + "content_type": [ + "application/json" + ] + }, + "url": { + "paths": [ + { + "path": "/_connector/{connector_id}/_pipeline", + "methods": [ + "PUT" + ], + "parts": { + "connector_id": { + "type": "string", + "description": "The unique identifier of the connector to be updated." + } + } + } + ] + }, + "body": { + "description": "An object with connector ingest pipeline configuration.", + "required": true + } + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/connector_sync_job.cancel.json b/rest-api-spec/src/main/resources/rest-api-spec/api/connector_sync_job.cancel.json new file mode 100644 index 0000000000000..883dd54bcb89b --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/connector_sync_job.cancel.json @@ -0,0 +1,32 @@ +{ + "connector_sync_job.cancel": { + "documentation": { + "url": "https://www.elastic.co/guide/en/enterprise-search/current/connectors.html", + "description": "Cancels a connector sync job." 
+ }, + "stability": "experimental", + "visibility": "feature_flag", + "feature_flag": "es.connector_api_feature_flag_enabled", + "headers": { + "accept": [ + "application/json" + ] + }, + "url": { + "paths": [ + { + "path": "/_connector/_sync_job/{connector_sync_job_id}/_cancel", + "methods": [ + "PUT" + ], + "parts": { + "connector_sync_job_id": { + "type": "string", + "description": "The unique identifier of the connector sync job to be canceled" + } + } + } + ] + } + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/connector_sync_job.check_in.json b/rest-api-spec/src/main/resources/rest-api-spec/api/connector_sync_job.check_in.json new file mode 100644 index 0000000000000..6c406a3a3d2c1 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/connector_sync_job.check_in.json @@ -0,0 +1,32 @@ +{ + "connector_sync_job.check_in": { + "documentation": { + "url": "https://www.elastic.co/guide/en/enterprise-search/current/connectors.html", + "description": "Checks in a connector sync job (refreshes 'last_seen')." 
+ }, + "stability": "experimental", + "visibility": "feature_flag", + "feature_flag": "es.connector_api_feature_flag_enabled", + "headers": { + "accept": [ + "application/json" + ] + }, + "url": { + "paths": [ + { + "path": "/_connector/_sync_job/{connector_sync_job_id}/_check_in", + "methods": [ + "PUT" + ], + "parts": { + "connector_sync_job_id": { + "type": "string", + "description": "The unique identifier of the connector sync job to be checked in" + } + } + } + ] + } + } +} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java index ad610954e86b6..19dfe598b5318 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java @@ -8,7 +8,6 @@ package org.elasticsearch.search; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.search.MultiSearchResponse; @@ -50,8 +49,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.notNullValue; -@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE) -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102257") +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST) public class SearchCancellationIT extends AbstractSearchCancellationTestCase { @Override @@ -206,7 +204,7 @@ public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exceptio public void testCancelMultiSearch() throws Exception { List plugins = initBlockFactory(); indexTestData(); - ActionFuture msearchResponse = client().prepareMultiSearch() + ActionFuture multiSearchResponse = client().prepareMultiSearch() .add( prepareSearch("test").addScriptField( "test_field", @@ -214,18 +212,24 
@@ public void testCancelMultiSearch() throws Exception { ) ) .execute(); - awaitForBlock(plugins); - cancelSearch(TransportMultiSearchAction.TYPE.name()); - disableBlocks(plugins); - for (MultiSearchResponse.Item item : msearchResponse.actionGet()) { - if (item.getFailure() != null) { - assertThat(ExceptionsHelper.unwrap(item.getFailure(), TaskCancelledException.class), notNullValue()); - } else { - assertFailures(item.getResponse()); - for (ShardSearchFailure shardFailure : item.getResponse().getShardFailures()) { - assertThat(ExceptionsHelper.unwrap(shardFailure.getCause(), TaskCancelledException.class), notNullValue()); + MultiSearchResponse response = null; + try { + awaitForBlock(plugins); + cancelSearch(TransportMultiSearchAction.TYPE.name()); + disableBlocks(plugins); + response = multiSearchResponse.actionGet(); + for (MultiSearchResponse.Item item : response) { + if (item.getFailure() != null) { + assertThat(ExceptionsHelper.unwrap(item.getFailure(), TaskCancelledException.class), notNullValue()); + } else { + assertFailures(item.getResponse()); + for (ShardSearchFailure shardFailure : item.getResponse().getShardFailures()) { + assertThat(ExceptionsHelper.unwrap(shardFailure.getCause(), TaskCancelledException.class), notNullValue()); + } } } + } finally { + if (response != null) response.decRef(); } } @@ -288,12 +292,11 @@ public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception assertTrue("All SearchShardTasks should then be cancelled", shardQueryTask.isCancelled()); } }, 30, TimeUnit.SECONDS); - shardTaskLatch.countDown(); // unblock the shardTasks, allowing the test to conclude. } finally { + shardTaskLatch.countDown(); // unblock the shardTasks, allowing the test to conclude. 
searchThread.join(); - for (ScriptedBlockPlugin plugin : plugins) { - plugin.setBeforeExecution(() -> {}); - } + plugins.forEach(plugin -> plugin.setBeforeExecution(() -> {})); + searchShardBlockingPlugins.forEach(plugin -> plugin.setRunOnNewReaderContext((ReaderContext c) -> {})); } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/DiversifiedSamplerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/DiversifiedSamplerIT.java index 3a313cec29402..5a58780a24817 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/DiversifiedSamplerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/DiversifiedSamplerIT.java @@ -40,7 +40,7 @@ @ESIntegTestCase.SuiteScopeTestCase public class DiversifiedSamplerIT extends ESIntegTestCase { - public static final int NUM_SHARDS = 2; + private static final int NUM_SHARDS = 1; public String randomExecutionHint() { return randomBoolean() ? 
null : randomFrom(SamplerAggregator.ExecutionMode.values()).toString(); @@ -83,8 +83,9 @@ public void setupSuiteScopeCluster() throws Exception { prepareIndex("idx_unmapped_author").setId("" + i) .setSource("name", parts[2], "genre", parts[8], "price", Float.parseFloat(parts[3])) .get(); + // frequent refresh makes it more likely that more segments are created, hence we may parallelize the search across slices + indicesAdmin().refresh(new RefreshRequest()).get(); } - indicesAdmin().refresh(new RefreshRequest("test")).get(); } public void testIssue10719() throws Exception { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SamplerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SamplerIT.java index 7f46856cdd594..00779ba9b256e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SamplerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/SamplerIT.java @@ -81,8 +81,9 @@ public void setupSuiteScopeCluster() throws Exception { prepareIndex("idx_unmapped_author").setId("" + i) .setSource("name", parts[2], "genre", parts[8], "price", Float.parseFloat(parts[3])) .get(); + // frequent refresh makes it more likely that more segments are created, hence we may parallelize the search across slices + indicesAdmin().refresh(new RefreshRequest()).get(); } - indicesAdmin().refresh(new RefreshRequest("test")).get(); } public void testIssue10719() throws Exception { diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 44f98305d2997..4627a3d907133 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -181,6 +181,7 @@ static TransportVersion def(int id) { public static final TransportVersion GET_API_KEY_INVALIDATION_TIME_ADDED = 
def(8_548_00_0); public static final TransportVersion ML_INFERENCE_GET_MULTIPLE_MODELS = def(8_549_00_0); public static final TransportVersion INFERENCE_SERVICE_RESULTS_ADDED = def(8_550_00_0); + public static final TransportVersion ESQL_PROFILE = def(8_551_00_0); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/action/ActionModule.java b/server/src/main/java/org/elasticsearch/action/ActionModule.java index e0f01405bcf0f..8e008dc57c81b 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionModule.java +++ b/server/src/main/java/org/elasticsearch/action/ActionModule.java @@ -266,6 +266,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.NamedRegistry; import org.elasticsearch.common.inject.AbstractModule; @@ -516,6 +517,7 @@ public ActionModule( SystemIndices systemIndices, Tracer tracer, ClusterService clusterService, + RerouteService rerouteService, List> reservedStateHandlers, RestExtension restExtension ) { @@ -562,7 +564,7 @@ public ActionModule( } else { restController = new RestController(restInterceptor, nodeClient, circuitBreakerService, usageService, tracer); } - reservedClusterStateService = new ReservedClusterStateService(clusterService, reservedStateHandlers); + reservedClusterStateService = new ReservedClusterStateService(clusterService, rerouteService, reservedStateHandlers); this.restExtension = restExtension; } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java 
b/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java index e3373ded94dc7..9f4c42a810563 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java @@ -45,6 +45,7 @@ public class TransportUpdateDesiredNodesAction extends TransportMasterNodeAction { private static final Logger logger = LogManager.getLogger(TransportUpdateDesiredNodesAction.class); + private final RerouteService rerouteService; private final FeatureService featureService; private final Consumer> desiredNodesValidator; private final MasterServiceTaskQueue taskQueue; @@ -53,6 +54,7 @@ public class TransportUpdateDesiredNodesAction extends TransportMasterNodeAction public TransportUpdateDesiredNodesAction( TransportService transportService, ClusterService clusterService, + RerouteService rerouteService, FeatureService featureService, ThreadPool threadPool, ActionFilters actionFilters, @@ -62,6 +64,7 @@ public TransportUpdateDesiredNodesAction( this( transportService, clusterService, + rerouteService, featureService, threadPool, actionFilters, @@ -74,6 +77,7 @@ public TransportUpdateDesiredNodesAction( TransportUpdateDesiredNodesAction( TransportService transportService, ClusterService clusterService, + RerouteService rerouteService, FeatureService featureService, ThreadPool threadPool, ActionFilters actionFilters, @@ -93,12 +97,13 @@ public TransportUpdateDesiredNodesAction( UpdateDesiredNodesResponse::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.rerouteService = rerouteService; this.featureService = featureService; this.desiredNodesValidator = desiredNodesValidator; this.taskQueue = clusterService.createTaskQueue( "update-desired-nodes", Priority.URGENT, - new UpdateDesiredNodesExecutor(clusterService.getRerouteService(), allocationService) + new 
UpdateDesiredNodesExecutor(rerouteService, allocationService) ); } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/settings/TransportClusterUpdateSettingsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/settings/TransportClusterUpdateSettingsAction.java index da44265f87436..e4093486da39c 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/settings/TransportClusterUpdateSettingsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/settings/TransportClusterUpdateSettingsAction.java @@ -22,6 +22,7 @@ import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.Priority; @@ -48,12 +49,14 @@ public class TransportClusterUpdateSettingsAction extends TransportMasterNodeAct private static final Logger logger = LogManager.getLogger(TransportClusterUpdateSettingsAction.class); + private final RerouteService rerouteService; private final ClusterSettings clusterSettings; @Inject public TransportClusterUpdateSettingsAction( TransportService transportService, ClusterService clusterService, + RerouteService rerouteService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, @@ -71,6 +74,7 @@ public TransportClusterUpdateSettingsAction( ClusterUpdateSettingsResponse::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.rerouteService = rerouteService; this.clusterSettings = clusterSettings; } @@ -191,7 +195,7 @@ private void reroute(final boolean updateSettingsAcked) { // the components (e.g. 
FilterAllocationDecider), so the changes made by the first call aren't visible to the components // until the ClusterStateListener instances have been invoked, but are visible after the first update task has been // completed. - clusterService.getRerouteService().reroute(REROUTE_TASK_SOURCE, Priority.URGENT, new ActionListener<>() { + rerouteService.reroute(REROUTE_TASK_SOURCE, Priority.URGENT, new ActionListener<>() { @Override public void onResponse(Void ignored) { listener.onResponse( diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java index 32c99f5baba85..3ad5e7fa43fe1 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java @@ -35,7 +35,6 @@ import org.elasticsearch.cluster.routing.allocation.command.AllocationCommands; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; import org.elasticsearch.cluster.routing.allocation.decider.Decision; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.logging.ESLogMessage; @@ -382,8 +381,7 @@ public CommandsResult reroute( * state. Should be called after every change to the cluster that affects the routing table and/or the balance of shards. *

* This method is expensive in larger clusters. Wherever possible you should invoke this method asynchronously using - * {@link RerouteService#reroute} to batch up invocations rather than calling the method directly. The node's reroute service is - * typically obtained from {@link ClusterService#getRerouteService}. + * {@link RerouteService#reroute} to batch up invocations rather than calling the method directly. * * @return an updated cluster state, or the same instance that was passed as an argument if no changes were made. */ @@ -400,8 +398,7 @@ public ClusterState reroute(ClusterState clusterState, String reason, ActionList * state. Should be called after every change to the cluster that affects the routing table and/or the balance of shards. *

* This method is expensive in larger clusters. Wherever possible you should invoke this method asynchronously using - * {@link RerouteService#reroute} to batch up invocations rather than calling the method directly. The node's reroute service is - * typically obtained from {@link ClusterService#getRerouteService}. + * {@link RerouteService#reroute} to batch up invocations rather than calling the method directly. * * @return an updated cluster state, or the same instance that was passed as an argument if no changes were made. */ diff --git a/server/src/main/java/org/elasticsearch/cluster/service/ClusterService.java b/server/src/main/java/org/elasticsearch/cluster/service/ClusterService.java index 67b6d64775dff..5c14b2ee1cbdf 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/ClusterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/ClusterService.java @@ -19,7 +19,6 @@ import org.elasticsearch.cluster.NodeConnectionsService; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.OperationRouting; -import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.common.Priority; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.settings.ClusterSettings; @@ -31,8 +30,6 @@ import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.ThreadPool; -import java.util.function.Supplier; - public class ClusterService extends AbstractLifecycleComponent { private final MasterService masterService; @@ -56,24 +53,11 @@ public class ClusterService extends AbstractLifecycleComponent { private final String nodeName; - private final Supplier rerouteService; - public ClusterService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool, TaskManager taskManager) { - this(settings, clusterSettings, threadPool, taskManager, () -> { throw new IllegalStateException("RerouteService not provided"); }); - } - 
- public ClusterService( - Settings settings, - ClusterSettings clusterSettings, - ThreadPool threadPool, - TaskManager taskManager, - Supplier rerouteService - ) { this( settings, clusterSettings, new MasterService(settings, clusterSettings, threadPool, taskManager), - rerouteService, new ClusterApplierService(Node.NODE_NAME_SETTING.get(settings), settings, clusterSettings, threadPool) ); } @@ -83,27 +67,10 @@ public ClusterService( ClusterSettings clusterSettings, MasterService masterService, ClusterApplierService clusterApplierService - ) { - this( - settings, - clusterSettings, - masterService, - () -> { throw new IllegalStateException("RerouteService not provided"); }, - clusterApplierService - ); - } - - public ClusterService( - Settings settings, - ClusterSettings clusterSettings, - MasterService masterService, - Supplier rerouteService, - ClusterApplierService clusterApplierService ) { this.settings = settings; this.nodeName = Node.NODE_NAME_SETTING.get(settings); this.masterService = masterService; - this.rerouteService = rerouteService; this.operationRouting = new OperationRouting(settings, clusterSettings); this.clusterSettings = clusterSettings; this.clusterName = ClusterName.CLUSTER_NAME_SETTING.get(settings); @@ -120,10 +87,6 @@ public synchronized void setNodeConnectionsService(NodeConnectionsService nodeCo clusterApplierService.setNodeConnectionsService(nodeConnectionsService); } - public RerouteService getRerouteService() { - return rerouteService.get(); - } - @Override protected synchronized void doStart() { clusterApplierService.start(); diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java index 5a49896cf1a36..4eaf9b5636623 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java @@ -75,11 
+75,11 @@ public static Iterator field(String name, String value) { } /** - * Creates an Iterator to serialize a named field where the value is represented by a chunked ToXContext. + * Creates an Iterator to serialize a named field where the value is represented by a {@link ChunkedToXContentObject}. * Chunked equivalent for {@code XContentBuilder field(String name, ToXContent value)} * @param name name of the field - * @param value ChunkedToXContent value for this field (single value, object or array) - * @param params ToXContent params to propagate for XContent serialization + * @param value value for this field + * @param params params to propagate for XContent serialization * @return Iterator composing field name and value serialization */ public static Iterator field(String name, ChunkedToXContentObject value, ToXContent.Params params) { @@ -90,6 +90,22 @@ public static Iterator array(String name, Iterator array(String name, Iterator contents, ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startArray(name), + Iterators.flatMap(contents, c -> c.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + public static Iterator wrapWithObject(String name, Iterator iterator) { return Iterators.concat(startObject(name), iterator, endObject()); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 499cf5d5ca64f..2f83310ea2388 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; import java.io.Closeable; import java.util.List; @@ -18,6 +19,8 @@ public interface InferenceService extends Closeable { + default void init(Client client) {} + String 
name(); /** @@ -52,7 +55,20 @@ public interface InferenceService extends Closeable { * @param secrets Sensitive configuration options (e.g. api key) * @return The parsed {@link Model} */ - Model parsePersistedConfig(String modelId, TaskType taskType, Map config, Map secrets); + Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map config, Map secrets); + + /** + * Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. + * This function modifies {@code config map}, fields are removed from the map as they are read. + * + * If the map contains unrecognized configuration options, no error is thrown. + * + * @param modelId Model Id + * @param taskType The model task type + * @param config Configuration options + * @return The parsed {@link Model} + */ + Model parsePersistedConfig(String modelId, TaskType taskType, Map config); /** * Perform inference on the model. diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java new file mode 100644 index 0000000000000..4b42e8ca53854 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.client.internal.Client; + +import java.util.List; + +/** + * SPI extension that define inference services + */ +public interface InferenceServiceExtension { + + List getInferenceServiceFactories(); + + record InferenceServiceFactoryContext(Client client) {} + + interface Factory { + /** + * InferenceServices are created from the factory context + */ + InferenceService create(InferenceServiceFactoryContext context); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index ac1439150f8ec..a0ed7bbd82b24 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -8,9 +8,9 @@ package org.elasticsearch.inference; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.plugins.InferenceServicePlugin; import java.io.IOException; import java.util.ArrayList; @@ -26,18 +26,18 @@ public class InferenceServiceRegistry extends AbstractLifecycleComponent { private final List namedWriteables = new ArrayList<>(); public InferenceServiceRegistry( - List inferenceServicePlugins, - InferenceServicePlugin.InferenceServiceFactoryContext factoryContext + List inferenceServicePlugins, + InferenceServiceExtension.InferenceServiceFactoryContext factoryContext ) { // TODO check names are unique services = inferenceServicePlugins.stream() .flatMap(r -> r.getInferenceServiceFactories().stream()) .map(factory -> factory.create(factoryContext)) .collect(Collectors.toMap(InferenceService::name, Function.identity())); + } - for (var plugin : inferenceServicePlugins) { - 
namedWriteables.addAll(plugin.getInferenceServiceNamedWriteables()); - } + public void init(Client client) { + services.values().forEach(s -> s.init(client)); } public Map getServices() { diff --git a/server/src/main/java/org/elasticsearch/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java index eedb67a8111e5..02be39d8a653d 100644 --- a/server/src/main/java/org/elasticsearch/inference/Model.java +++ b/server/src/main/java/org/elasticsearch/inference/Model.java @@ -27,6 +27,14 @@ public Model(ModelConfigurations configurations) { this(configurations, new ModelSecrets()); } + public String getModelId() { + return configurations.getModelId(); + } + + public TaskType getTaskType() { + return configurations.getTaskType(); + } + /** * Returns the model's non-sensitive configurations (e.g. service name). */ diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index 9e96a7c4c52d0..5afedee873145 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -20,7 +20,13 @@ public enum TaskType implements Writeable { TEXT_EMBEDDING, - SPARSE_EMBEDDING; + SPARSE_EMBEDDING, + ANY { + @Override + public boolean isAnyOrSame(TaskType other) { + return true; + } + }; public static String NAME = "task_type"; @@ -37,6 +43,16 @@ public static TaskType fromStringOrStatusException(String name) { } } + /** + * Return true if the {@code other} is the {@link #ANY} type + * or the same as this. + * @param other The other + * @return True if same or any. 
+ */ + public boolean isAnyOrSame(TaskType other) { + return other == TaskType.ANY || other == this; + } + @Override public String toString() { return name().toLowerCase(Locale.ROOT); diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 0623c3b196e45..19a1310ed86aa 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -121,7 +121,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -139,7 +138,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -516,13 +514,6 @@ private void createClientAndRegistries(Settings settings, ThreadPool threadPool, localNodeFactory = new Node.LocalNodeFactory(settings, nodeEnvironment.nodeId()); - InferenceServiceRegistry inferenceServiceRegistry = new InferenceServiceRegistry( - pluginsService.filterPlugins(InferenceServicePlugin.class).toList(), - new InferenceServicePlugin.InferenceServiceFactoryContext(client) - ); - resourcesToClose.add(inferenceServiceRegistry); - modules.bindToInstance(InferenceServiceRegistry.class, inferenceServiceRegistry); - namedWriteableRegistry = new NamedWriteableRegistry( Stream.of( NetworkModule.getNamedWriteables().stream(), @@ 
-530,8 +521,7 @@ private void createClientAndRegistries(Settings settings, ThreadPool threadPool, searchModule.getNamedWriteables().stream(), pluginsService.flatMap(Plugin::getNamedWriteables), ClusterModule.getNamedWriteables().stream(), - SystemIndexMigrationExecutor.getNamedWriteables().stream(), - inferenceServiceRegistry.getNamedWriteables().stream() + SystemIndexMigrationExecutor.getNamedWriteables().stream() ).flatMap(Function.identity()).toList() ); xContentRegistry = new NamedXContentRegistry( @@ -607,8 +597,7 @@ private void construct( telemetryProvider.getTracer() ); - final SetOnce rerouteServiceReference = new SetOnce<>(); - ClusterService clusterService = createClusterService(settingsModule, threadPool, taskManager, rerouteServiceReference::get); + ClusterService clusterService = createClusterService(settingsModule, threadPool, taskManager); clusterService.addStateApplier(scriptService); Supplier documentParsingObserverSupplier = getDocumentParsingObserverSupplier(); @@ -628,6 +617,7 @@ private void construct( SystemIndices systemIndices = createSystemIndices(settings); final SetOnce repositoriesServiceReference = new SetOnce<>(); + final SetOnce rerouteServiceReference = new SetOnce<>(); final ClusterInfoService clusterInfoService = serviceProvider.newClusterInfoService( pluginsService, settings, @@ -759,6 +749,7 @@ private void construct( record PluginServiceInstances( Client client, ClusterService clusterService, + RerouteService rerouteService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, @@ -777,6 +768,7 @@ record PluginServiceInstances( PluginServiceInstances pluginServices = new PluginServiceInstances( client, clusterService, + rerouteService, threadPool, createResourceWatcherService(settings, threadPool), scriptService, @@ -814,6 +806,7 @@ record PluginServiceInstances( systemIndices, telemetryProvider.getTracer(), clusterService, + rerouteService, buildReservedStateHandlers( 
settingsModule, clusterService, @@ -900,6 +893,7 @@ record PluginServiceInstances( SnapshotsService snapshotsService = new SnapshotsService( settings, clusterService, + rerouteService, clusterModule.getIndexNameExpressionResolver(), repositoryService, transportService, @@ -1074,18 +1068,12 @@ record PluginServiceInstances( postInjection(clusterModule, actionModule, clusterService, transportService, featureService); } - private ClusterService createClusterService( - SettingsModule settingsModule, - ThreadPool threadPool, - TaskManager taskManager, - Supplier rerouteService - ) { + private ClusterService createClusterService(SettingsModule settingsModule, ThreadPool threadPool, TaskManager taskManager) { ClusterService clusterService = new ClusterService( settingsModule.getSettings(), settingsModule.getClusterSettings(), threadPool, - taskManager, - rerouteService + taskManager ); resourcesToClose.add(clusterService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java deleted file mode 100644 index 2672a4b8fcbcf..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.inference.InferenceService; - -import java.util.List; - -/** - * InferenceServicePlugins implement an inference service - */ -public interface InferenceServicePlugin { - - List getInferenceServiceFactories(); - - record InferenceServiceFactoryContext(Client client) {} - - interface Factory { - /** - * InferenceServices are created from the factory context - */ - InferenceService create(InferenceServiceFactoryContext context); - } - - /** - * The named writables defined and used by each of the implemented - * InferenceServices. Each service should define named writables for - * - {@link org.elasticsearch.inference.TaskSettings} - * - {@link org.elasticsearch.inference.ServiceSettings} - * And optionally for {@link org.elasticsearch.inference.InferenceResults} - * if the service uses a new type of result. 
- * @return All named writables defined by the services - */ - List getInferenceServiceNamedWriteables(); -} diff --git a/server/src/main/java/org/elasticsearch/plugins/Plugin.java b/server/src/main/java/org/elasticsearch/plugins/Plugin.java index de9f8186865aa..12ad05e2bc710 100644 --- a/server/src/main/java/org/elasticsearch/plugins/Plugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/Plugin.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.component.LifecycleComponent; @@ -78,6 +79,11 @@ public interface PluginServices { */ ClusterService clusterService(); + /** + * A service to reroute shards to other nodes + */ + RerouteService rerouteService(); + /** * A service to allow retrieving an executor to run an async action */ diff --git a/server/src/main/java/org/elasticsearch/reservedstate/service/ReservedClusterStateService.java b/server/src/main/java/org/elasticsearch/reservedstate/service/ReservedClusterStateService.java index f6d5ab3ead6af..76c2007dc8d8e 100644 --- a/server/src/main/java/org/elasticsearch/reservedstate/service/ReservedClusterStateService.java +++ b/server/src/main/java/org/elasticsearch/reservedstate/service/ReservedClusterStateService.java @@ -16,6 +16,7 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.ReservedStateErrorMetadata; import org.elasticsearch.cluster.metadata.ReservedStateMetadata; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.common.Priority; @@ -83,12 
+84,16 @@ public class ReservedClusterStateService { * @param clusterService for fetching and saving the modified state * @param handlerList a list of reserved state handlers, which we use to transform the state */ - public ReservedClusterStateService(ClusterService clusterService, List> handlerList) { + public ReservedClusterStateService( + ClusterService clusterService, + RerouteService rerouteService, + List> handlerList + ) { this.clusterService = clusterService; this.updateTaskQueue = clusterService.createTaskQueue( "reserved state update", Priority.URGENT, - new ReservedStateUpdateTaskExecutor(clusterService.getRerouteService()) + new ReservedStateUpdateTaskExecutor(rerouteService) ); this.errorTaskQueue = clusterService.createTaskQueue("reserved state error", Priority.URGENT, new ReservedStateErrorTaskExecutor()); this.handlers = handlerList.stream().collect(Collectors.toMap(ReservedClusterStateHandler::name, Function.identity())); diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 548e3fea9d91c..9e59bfda96d19 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -540,6 +540,8 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, return; } } + // TODO: i think it makes sense to always do a canMatch here and + // return an empty response (not null response) in case canMatch is false? 
ensureAfterSeqNoRefreshed(shard, orig, () -> executeQueryPhase(orig, task), l); })); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregationBuilder.java index e77b15e1ed1d4..0eecdc9e2a6e5 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregationBuilder.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.Map; import java.util.Objects; +import java.util.function.ToLongFunction; public class DiversifiedAggregationBuilder extends ValuesSourceAggregationBuilder { public static final String NAME = "diversified_sampler"; @@ -189,4 +190,9 @@ public String getType() { public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ZERO; } + + @Override + public boolean supportsParallelCollection(ToLongFunction fieldCardinalityResolver) { + return false; + } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregationBuilder.java index 5c3208418df08..0f85e5e11064c 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregationBuilder.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.Map; import java.util.Objects; +import java.util.function.ToLongFunction; public class SamplerAggregationBuilder extends AbstractAggregationBuilder { public static final String NAME = "sampler"; @@ -141,4 +142,9 @@ public String getType() { public TransportVersion getMinimalSupportedVersion() { return 
TransportVersions.ZERO; } + + @Override + public boolean supportsParallelCollection(ToLongFunction fieldCardinalityResolver) { + return false; + } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java index 2cadbd3d43494..f47e28bbc6dbd 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java @@ -9,7 +9,10 @@ package org.elasticsearch.search.aggregations.bucket.terms; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.MatchNoDocsQuery; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.index.query.QueryBuilder; @@ -23,6 +26,7 @@ import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.NonCollectingAggregator; import org.elasticsearch.search.aggregations.bucket.BucketUtils; +import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator.BucketCountThresholds; import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic; import org.elasticsearch.search.aggregations.support.AggregationContext; @@ -32,6 +36,7 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; +import 
org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.xcontent.ParseField; import java.io.IOException; @@ -81,7 +86,7 @@ private static SignificantTermsAggregatorSupplier bytesSupplier() { if (executionHint != null) { execution = ExecutionMode.fromString(executionHint, deprecationLogger); } - if (valuesSourceConfig.hasOrdinals() == false) { + if (valuesSourceConfig.hasOrdinals() == false || matchNoDocs(context, parent)) { execution = ExecutionMode.MAP; } if (execution == null) { @@ -115,6 +120,30 @@ private static SignificantTermsAggregatorSupplier bytesSupplier() { }; } + /** + * Whether the aggregation will execute. If the main query matches no documents and parent aggregation isn't a global or terms + * aggregation with min_doc_count = 0, the aggregator will not really execute. In those cases it doesn't make sense to load + * global ordinals. + *

+ * Some searches that will never match can still fall through and we end up running a query that will produce no results. + However even in that case we sometimes do expensive things like loading global ordinals. This method should prevent this. + Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, SearchShardTask, ActionListener)} + always does a can match then we don't need this code here. + */ + static boolean matchNoDocs(AggregationContext context, Aggregator parent) { + if (context.query() instanceof MatchNoDocsQuery) { + while (parent != null) { + if (parent instanceof GlobalAggregator) { + return false; + } + parent = parent.parent(); + } + return true; + } else { + return false; + } + } + /** * This supplier is used for all fields that expect to be aggregated as a numeric value. * This includes floating points, and formatted types that use numerics internally for storage (date, boolean, etc) @@ -296,7 +325,6 @@ protected Aggregator doCreateInternal(Aggregator parent, CardinalityUpperBound c public enum ExecutionMode { MAP(new ParseField("map")) { - @Override Aggregator create( String name, @@ -335,7 +363,6 @@ Aggregator create( }, GLOBAL_ORDINALS(new ParseField("global_ordinals")) { - @Override Aggregator create( String name, diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java index e17cd828a24d0..68a4ffca22b51 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java @@ -45,6 +45,8 @@ import java.util.function.Function; import java.util.function.LongPredicate; +import static org.elasticsearch.search.aggregations.bucket.terms.SignificantTermsAggregatorFactory.matchNoDocs; + public class
TermsAggregatorFactory extends ValuesSourceAggregatorFactory { static Boolean REMAP_GLOBAL_ORDS, COLLECT_SEGMENT_ORDS; @@ -107,7 +109,7 @@ private static TermsAggregatorSupplier bytesSupplier() { execution = ExecutionMode.fromString(executionHint); } // In some cases, using ordinals is just not supported: override it - if (valuesSource.hasOrdinals() == false) { + if (valuesSource.hasOrdinals() == false || matchNoDocs(context, parent)) { execution = ExecutionMode.MAP; } if (execution == null) { diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java index 499ac7022403e..3b872f550fc6f 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java @@ -51,6 +51,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.routing.RoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.service.ClusterService; @@ -153,6 +154,8 @@ public final class SnapshotsService extends AbstractLifecycleComponent implement private final ClusterService clusterService; + private final RerouteService rerouteService; + private final IndexNameExpressionResolver indexNameExpressionResolver; private final RepositoriesService repositoriesService; @@ -203,6 +206,7 @@ public final class SnapshotsService extends AbstractLifecycleComponent implement public SnapshotsService( Settings settings, ClusterService clusterService, + RerouteService rerouteService, IndexNameExpressionResolver indexNameExpressionResolver, RepositoriesService repositoriesService, TransportService transportService, @@ -210,6 +214,7 @@ public SnapshotsService( SystemIndices 
systemIndices ) { this.clusterService = clusterService; + this.rerouteService = rerouteService; this.indexNameExpressionResolver = indexNameExpressionResolver; this.repositoriesService = repositoriesService; this.threadPool = transportService.getThreadPool(); @@ -3712,7 +3717,7 @@ public ClusterState execute(BatchExecutionContext batchExecutionCo final ClusterState state = batchExecutionContext.initialState(); final SnapshotShardsUpdateContext shardsUpdateContext = new SnapshotShardsUpdateContext( batchExecutionContext, - () -> clusterService.getRerouteService().reroute("after shards snapshot update", Priority.NORMAL, ActionListener.noop()) + () -> rerouteService.reroute("after shards snapshot update", Priority.NORMAL, ActionListener.noop()) ); final SnapshotsInProgress initialSnapshots = SnapshotsInProgress.get(state); SnapshotsInProgress snapshotsInProgress = shardsUpdateContext.computeUpdatedState(); diff --git a/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java b/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java index abf79243b6a61..2263bfe78f218 100644 --- a/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java +++ b/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java @@ -353,6 +353,7 @@ public void testActionTransportException() throws IOException { assertEquals("[name?][" + transportAddress + "][ACTION BABY!] 
message?", ex.getMessage()); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102868") public void testSearchContextMissingException() throws IOException { ShardSearchContextId contextId = new ShardSearchContextId(UUIDs.randomBase64UUID(), randomLong()); TransportVersion version = TransportVersionUtils.randomVersion(random()); diff --git a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java index 4f72357f83325..a076537bb7351 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionModuleTests.java @@ -121,6 +121,7 @@ public void testSetupRestHandlerContainsKnownBuiltin() { null, null, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ); @@ -181,6 +182,7 @@ public String getName() { null, null, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ); @@ -234,6 +236,7 @@ public List getRestHandlers( null, null, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ); @@ -282,6 +285,7 @@ public void test3rdPartyHandlerIsNotInstalled() { null, null, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ) @@ -314,13 +318,14 @@ public void test3rdPartyRestControllerIsNotInstalled() { settingsModule.getClusterSettings(), settingsModule.getSettingsFilter(), threadPool, - Arrays.asList(secPlugin), + List.of(secPlugin), null, null, usageService, null, null, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ) diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java index 4e2948eafc1d7..00f46d8c42bf0 100644 --- 
a/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesActionTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.cluster.metadata.DesiredNodesTestCase; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.features.FeatureService; @@ -50,6 +51,7 @@ public void testWriteBlocks() { final TransportUpdateDesiredNodesAction action = new TransportUpdateDesiredNodesAction( transportService, mock(ClusterService.class), + mock(RerouteService.class), mock(FeatureService.class), threadPool, mock(ActionFilters.class), @@ -78,6 +80,7 @@ public void testNoBlocks() { final TransportUpdateDesiredNodesAction action = new TransportUpdateDesiredNodesAction( transportService, mock(ClusterService.class), + mock(RerouteService.class), mock(FeatureService.class), threadPool, mock(ActionFilters.class), diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/settings/ClusterUpdateSettingsRequestTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/settings/ClusterUpdateSettingsRequestTests.java index 6373b94ffb94a..a1d2ef33d85f3 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/settings/ClusterUpdateSettingsRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/settings/ClusterUpdateSettingsRequestTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.routing.RerouteService; import 
org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.ClusterSettings; @@ -91,6 +92,7 @@ public void testOperatorHandler() throws IOException { TransportClusterUpdateSettingsAction action = new TransportClusterUpdateSettingsAction( transportService, mock(ClusterService.class), + mock(RerouteService.class), threadPool, mock(ActionFilters.class), mock(IndexNameExpressionResolver.class), diff --git a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java index 86c48c1e183ea..559d3fce9cebf 100644 --- a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java @@ -1153,6 +1153,7 @@ public Collection getRestHeaders() { null, null, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/GeoPointFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/GeoPointFieldMapperTests.java index cce44504d4f3e..69cbb1d90b951 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/GeoPointFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/GeoPointFieldMapperTests.java @@ -705,7 +705,7 @@ protected Function loadBlockExpected() { protected static Object asJacksonNumberOutput(long l) { // Cast to int to mimic jackson-core behaviour in NumberOutput.outputLong() - if (l < 0 && l > Integer.MIN_VALUE || l >= 0 && l <= Integer.MAX_VALUE) { + if (l < 0 && l >= Integer.MIN_VALUE || l >= 0 && l <= Integer.MAX_VALUE) { return (int) l; } else { return l; diff --git a/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java b/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java 
index 96b4df3b856b7..5968be34e985a 100644 --- a/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java @@ -67,14 +67,12 @@ public void setUp() throws Exception { threadpool = new TestThreadPool("file_settings_service_tests"); - var reroute = mock(RerouteService.class); clusterService = spy( new ClusterService( Settings.builder().put(NODE_NAME_SETTING.getKey(), "test").build(), new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), threadpool, - new TaskManager(Settings.EMPTY, threadpool, Set.of()), - () -> reroute + new TaskManager(Settings.EMPTY, threadpool, Set.of()) ) ); @@ -101,7 +99,11 @@ public void setUp() throws Exception { ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - controller = new ReservedClusterStateService(clusterService, List.of(new ReservedClusterSettingsAction(clusterSettings))); + controller = new ReservedClusterStateService( + clusterService, + mock(RerouteService.class), + List.of(new ReservedClusterSettingsAction(clusterSettings)) + ); fileSettingsService = spy(new FileSettingsService(clusterService, controller, env)); } diff --git a/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java b/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java index e8f5a71ad6fcb..fe9401284b9f5 100644 --- a/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java +++ b/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java @@ -85,6 +85,7 @@ public void testOperatorController() throws IOException { ReservedClusterStateService controller = new ReservedClusterStateService( clusterService, + mock(RerouteService.class), List.of(new 
ReservedClusterSettingsAction(clusterSettings)) ); @@ -147,10 +148,9 @@ public void testUpdateStateTasks() throws Exception { ClusterService clusterService = mock(ClusterService.class); RerouteService rerouteService = mock(RerouteService.class); - when(clusterService.getRerouteService()).thenReturn(rerouteService); ClusterState state = ClusterState.builder(new ClusterName("test")).build(); - ReservedStateUpdateTaskExecutor taskExecutor = new ReservedStateUpdateTaskExecutor(clusterService.getRerouteService()); + ReservedStateUpdateTaskExecutor taskExecutor = new ReservedStateUpdateTaskExecutor(rerouteService); AtomicBoolean successCalled = new AtomicBoolean(false); @@ -362,7 +362,9 @@ public void onFailure(Exception e) {} ); ClusterService clusterService = mock(ClusterService.class); - final var controller = spy(new ReservedClusterStateService(clusterService, List.of(newStateMaker, exceptionThrower))); + final var controller = spy( + new ReservedClusterStateService(clusterService, mock(RerouteService.class), List.of(newStateMaker, exceptionThrower)) + ); var trialRunResult = controller.trialRun("namespace_one", state, chunk, new LinkedHashSet<>(orderedHandlers)); assertEquals(0, trialRunResult.nonStateTransforms().size()); @@ -440,7 +442,7 @@ public void testHandlerOrdering() { ReservedClusterStateHandler> oh3 = makeHandlerHelper("three", List.of("two")); ClusterService clusterService = mock(ClusterService.class); - final var controller = new ReservedClusterStateService(clusterService, List.of(oh1, oh2, oh3)); + final var controller = new ReservedClusterStateService(clusterService, mock(RerouteService.class), List.of(oh1, oh2, oh3)); Collection ordered = controller.orderedStateHandlers(Set.of("one", "two", "three")); assertThat(ordered, contains("two", "three", "one")); @@ -460,7 +462,7 @@ public void testHandlerOrdering() { // Change the second handler so that we create cycle oh2 = makeHandlerHelper("two", List.of("one")); - final var controller1 = new 
ReservedClusterStateService(clusterService, List.of(oh1, oh2)); + final var controller1 = new ReservedClusterStateService(clusterService, mock(RerouteService.class), List.of(oh1, oh2)); assertThat( expectThrows(IllegalStateException.class, () -> controller1.orderedStateHandlers(Set.of("one", "two"))).getMessage(), @@ -484,6 +486,7 @@ public void testDuplicateHandlerNames() { IllegalStateException.class, () -> new ReservedClusterStateService( clusterService, + mock(RerouteService.class), List.of(new ReservedClusterSettingsAction(clusterSettings), new TestHandler()) ) ).getMessage().startsWith("Duplicate key cluster_settings") @@ -496,7 +499,7 @@ public void testCheckAndReportError() { when(clusterService.state()).thenReturn(state); when(clusterService.createTaskQueue(any(), any(), any())).thenReturn(mockTaskQueue()); - final var controller = spy(new ReservedClusterStateService(clusterService, List.of())); + final var controller = spy(new ReservedClusterStateService(clusterService, mock(RerouteService.class), List.of())); assertNull(controller.checkAndReportError("test", List.of(), null)); verify(controller, times(0)).updateErrorState(any()); @@ -568,7 +571,9 @@ public Map fromXContent(XContentParser parser) throws IOExceptio var orderedHandlers = List.of(exceptionThrower.name(), newStateMaker.name()); ClusterService clusterService = mock(ClusterService.class); - final var controller = spy(new ReservedClusterStateService(clusterService, List.of(newStateMaker, exceptionThrower))); + final var controller = spy( + new ReservedClusterStateService(clusterService, mock(RerouteService.class), List.of(newStateMaker, exceptionThrower)) + ); var trialRunResult = controller.trialRun("namespace_one", state, chunk, new LinkedHashSet<>(orderedHandlers)); @@ -631,7 +636,7 @@ public Map fromXContent(XContentParser parser) throws IOExceptio var chunk = new ReservedStateChunk(chunkMap, new ReservedStateVersion(2L, Version.CURRENT)); ClusterService clusterService = 
mock(ClusterService.class); - final var controller = spy(new ReservedClusterStateService(clusterService, handlers)); + final var controller = spy(new ReservedClusterStateService(clusterService, mock(RerouteService.class), handlers)); var trialRunResult = controller.trialRun( "namespace_one", diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedSamplerTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedSamplerTests.java index 797ace3f2b37c..6ac538f6c7ce9 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedSamplerTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedSamplerTests.java @@ -199,4 +199,21 @@ public void testDiversifiedSampler_noDocs() throws Exception { indexReader.close(); directory.close(); } + + public void testSupportsParallelCollection() { + DiversifiedAggregationBuilder sampler = new DiversifiedAggregationBuilder("name"); + if (randomBoolean()) { + sampler.field("field"); + } + if (randomBoolean()) { + sampler.maxDocsPerValue(randomIntBetween(1, 1000)); + } + if (randomBoolean()) { + sampler.subAggregation(new TermsAggregationBuilder("name").field("field")); + } + if (randomBoolean()) { + sampler.shardSize(randomIntBetween(1, 1000)); + } + assertFalse(sampler.supportsParallelCollection(null)); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregatorTests.java index 220c863def228..722a510ce381e 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/SamplerAggregatorTests.java @@ -132,4 +132,14 @@ public void testEmptyParentBucket() throws Exception { } } + public void 
testSupportsParallelCollection() { + SamplerAggregationBuilder sampler = new SamplerAggregationBuilder("name"); + if (randomBoolean()) { + sampler.subAggregation(new TermsAggregationBuilder("name").field("field")); + } + if (randomBoolean()) { + sampler.shardSize(randomIntBetween(1, 1000)); + } + assertFalse(sampler.supportsParallelCollection(null)); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorTests.java index 5c467893179ee..87d4137b5bc59 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorTests.java @@ -20,12 +20,15 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.index.mapper.BinaryFieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.LuceneDocument; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NumberFieldMapper; @@ -60,6 +63,7 @@ import java.util.TreeSet; import static org.elasticsearch.search.aggregations.AggregationBuilders.significantTerms; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorTests.doc; import static org.hamcrest.Matchers.equalTo; public class SignificantTermsAggregatorTests extends AggregatorTestCase { @@ -668,6 +672,26 @@ public 
void testThreeLayerLong() throws IOException { } } + public void testMatchNoDocsQuery() throws Exception { + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string", randomBoolean(), true, Collections.emptyMap()); + SignificantTermsAggregationBuilder aggregationBuilder = new SignificantTermsAggregationBuilder("_name").field("string"); + CheckedConsumer createIndex = iw -> { + iw.addDocument(doc(fieldType, "a", "b")); + iw.addDocument(doc(fieldType, "", "c", "a")); + iw.addDocument(doc(fieldType, "b", "d")); + iw.addDocument(doc(fieldType, "")); + }; + testCase( + createIndex, + (SignificantStringTerms result) -> { assertEquals(0, result.getBuckets().size()); }, + new AggTestConfig(aggregationBuilder, fieldType).withQuery(new MatchNoDocsQuery()) + ); + + debugTestCase(aggregationBuilder, new MatchNoDocsQuery(), createIndex, (result, impl, debug) -> { + assertEquals(impl, MapStringTermsAggregator.class); + }, fieldType); + } + private void addMixedTextDocs(IndexWriter w) throws IOException { for (int i = 0; i < 10; i++) { Document doc = new Document(); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java index b0d67879b26a1..204e9025ce9a2 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorTests.java @@ -28,6 +28,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TotalHits; @@ -285,6 +286,26 @@ public void testSimple() throws Exception { }, new AggTestConfig(aggregationBuilder, 
fieldType)); } + public void testMatchNoDocsQuery() throws Exception { + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string", randomBoolean(), true, Collections.emptyMap()); + TermsAggregationBuilder aggregationBuilder = new TermsAggregationBuilder("_name").field("string"); + CheckedConsumer createIndex = iw -> { + iw.addDocument(doc(fieldType, "a", "b")); + iw.addDocument(doc(fieldType, "", "c", "a")); + iw.addDocument(doc(fieldType, "b", "d")); + iw.addDocument(doc(fieldType, "")); + }; + testCase( + createIndex, + (InternalTerms result) -> { assertEquals(0, result.getBuckets().size()); }, + new AggTestConfig(aggregationBuilder, fieldType).withQuery(new MatchNoDocsQuery()) + ); + + debugTestCase(aggregationBuilder, new MatchNoDocsQuery(), createIndex, (result, impl, debug) -> { + assertEquals(impl, MapStringTermsAggregator.class); + }, fieldType); + } + public void testStringShardMinDocCount() throws IOException { MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("string", true, true, Collections.emptyMap()); for (TermsAggregatorFactory.ExecutionMode executionMode : TermsAggregatorFactory.ExecutionMode.values()) { @@ -419,7 +440,7 @@ public void testDelaysSubAggs() throws Exception { }); } - private List doc(MappedFieldType ft, String... values) { + static List doc(MappedFieldType ft, String... 
values) { List doc = new ArrayList(); for (String v : values) { BytesRef bytes = new BytesRef(v); diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 03c655d271e7c..a39012e616e37 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1616,7 +1616,6 @@ private final class TestClusterNode { settings, clusterSettings, masterService, - () -> (reason, priority, listener) -> listener.onResponse(null), new ClusterApplierService(node.getName(), settings, clusterSettings, threadPool) { @Override protected PrioritizedEsThreadPoolExecutor createThreadPoolExecutor() { @@ -1726,6 +1725,7 @@ protected void assertSnapshotOrGenericThread() { snapshotsService = new SnapshotsService( settings, clusterService, + (reason, priority, listener) -> listener.onResponse(null), indexNameExpressionResolver, repositoriesService, transportService, diff --git a/test/fixtures/url-fixture/Dockerfile b/test/fixtures/url-fixture/Dockerfile deleted file mode 100644 index d6c1443fa1f85..0000000000000 --- a/test/fixtures/url-fixture/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM openjdk:17.0.2 - -ARG port -ARG workingDir -ARG repositoryDir - -ENV URL_FIXTURE_PORT=${port} -ENV URL_FIXTURE_WORKING_DIR=${workingDir} -ENV URL_FIXTURE_REPO_DIR=${repositoryDir} - -ENTRYPOINT exec java -classpath "/fixture/shared/*" \ - fixture.url.URLFixture "$URL_FIXTURE_PORT" "$URL_FIXTURE_WORKING_DIR" "$URL_FIXTURE_REPO_DIR" - -EXPOSE $port diff --git a/test/fixtures/url-fixture/build.gradle b/test/fixtures/url-fixture/build.gradle index d7d9fd2964c19..d8bcce6ce8211 100644 --- a/test/fixtures/url-fixture/build.gradle +++ b/test/fixtures/url-fixture/build.gradle @@ -6,30 +6,9 @@ * Side Public License, v 1. 
*/ apply plugin: 'elasticsearch.java' -apply plugin: 'elasticsearch.test.fixtures' - description = 'Fixture for URL external service' -tasks.named("test").configure { enabled = false } dependencies { api project(':server') api project(':test:framework') } - -// These directories are shared between the URL repository and the FS repository in integration tests -project.ext { - fsRepositoryDir = file("${testFixturesDir}/fs-repository") -} - -tasks.named("preProcessFixture").configure { - dependsOn "jar", configurations.runtimeClasspath - doLast { - file("${testFixturesDir}/shared").mkdirs() - project.copy { - from jar - from configurations.runtimeClasspath - into "${testFixturesDir}/shared" - } - project.fsRepositoryDir.mkdirs() - } -} diff --git a/test/fixtures/url-fixture/docker-compose.yml b/test/fixtures/url-fixture/docker-compose.yml deleted file mode 100644 index edfc879b1cec3..0000000000000 --- a/test/fixtures/url-fixture/docker-compose.yml +++ /dev/null @@ -1,15 +0,0 @@ -version: '3' -services: - url-fixture: - build: - context: . 
- args: - port: 80 - workingDir: "/fixture/work" - repositoryDir: "/fixture/repo" - volumes: - - ./testfixtures_shared/shared:/fixture/shared - - ./testfixtures_shared/fs-repository:/fixture/repo - - ./testfixtures_shared/work:/fixture/work - ports: - - "80" diff --git a/test/fixtures/url-fixture/src/main/java/fixture/url/URLFixture.java b/test/fixtures/url-fixture/src/main/java/fixture/url/URLFixture.java index 3f6eed903765a..5192140f1af45 100644 --- a/test/fixtures/url-fixture/src/main/java/fixture/url/URLFixture.java +++ b/test/fixtures/url-fixture/src/main/java/fixture/url/URLFixture.java @@ -7,15 +7,17 @@ */ package fixture.url; -import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.fixture.AbstractHttpFixture; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestRule; import java.io.IOException; import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; import java.util.regex.Matcher; @@ -25,32 +27,17 @@ * This {@link URLFixture} exposes a filesystem directory over HTTP. It is used in repository-url * integration tests to expose a directory created by a regular FS repository. 
*/ -public class URLFixture extends AbstractHttpFixture { +public class URLFixture extends AbstractHttpFixture implements TestRule { private static final Pattern RANGE_PATTERN = Pattern.compile("bytes=(\\d+)-(\\d+)$"); - private final Path repositoryDir; + private final TemporaryFolder temporaryFolder; + private Path repositoryDir; /** * Creates a {@link URLFixture} */ - private URLFixture(final int port, final String workingDir, final String repositoryDir) { - super(workingDir, port); - this.repositoryDir = dir(repositoryDir); - } - - public static void main(String[] args) throws Exception { - if (args == null || args.length != 3) { - throw new IllegalArgumentException("URLFixture "); - } - String workingDirectory = args[1]; - if (Files.exists(dir(workingDirectory)) == false) { - throw new IllegalArgumentException("Configured working directory " + workingDirectory + " does not exist"); - } - String repositoryDirectory = args[2]; - if (Files.exists(dir(repositoryDirectory)) == false) { - throw new IllegalArgumentException("Configured repository directory " + repositoryDirectory + " does not exist"); - } - final URLFixture fixture = new URLFixture(Integer.parseInt(args[0]), workingDirectory, repositoryDirectory); - fixture.listen(InetAddress.getByName("0.0.0.0"), false); + public URLFixture() { + super(); + this.temporaryFolder = new TemporaryFolder(); } @Override @@ -107,8 +94,32 @@ private AbstractHttpFixture.Response handleGetRequest(Request request) throws IO } } - @SuppressForbidden(reason = "Paths#get is fine - we don't have environment here") - private static Path dir(final String dir) { - return Paths.get(dir); + @Override + protected void before() throws Throwable { + this.temporaryFolder.create(); + this.repositoryDir = temporaryFolder.newFolder("repoDir").toPath(); + InetSocketAddress inetSocketAddress = resolveAddress("0.0.0.0", 0); + listen(inetSocketAddress, false); + } + + public String getRepositoryDir() { + if (repositoryDir == null) { + throw new 
IllegalStateException("Rule has not been started yet"); + } + return repositoryDir.toFile().getAbsolutePath(); + } + + private static InetSocketAddress resolveAddress(String address, int port) { + try { + return new InetSocketAddress(InetAddress.getByName(address), port); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + } + + @Override + protected void after() { + super.stop(); + this.temporaryFolder.delete(); } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java index 2856decf54ee1..52634ae6672b2 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java @@ -35,7 +35,8 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -123,8 +124,9 @@ protected void cancelSearch(String action) { } protected SearchResponse ensureSearchWasCancelled(ActionFuture searchResponse) { + SearchResponse response = null; try { - SearchResponse response = searchResponse.actionGet(); + response = searchResponse.actionGet(); logger.info("Search response {}", response); assertNotEquals("At least one shard should have failed", 0, response.getFailedShards()); for (ShardSearchFailure failure : response.getShardFailures()) { @@ -137,6 +139,8 @@ protected SearchResponse ensureSearchWasCancelled(ActionFuture s assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST)); logger.info("All shards failed with", ex); return null; + } finally { + if (response != null) response.decRef(); } 
} @@ -153,7 +157,7 @@ public static class ScriptedBlockPlugin extends MockScriptPlugin { private final AtomicInteger hits = new AtomicInteger(); - private final AtomicBoolean shouldBlock = new AtomicBoolean(true); + private final Semaphore shouldBlock = new Semaphore(Integer.MAX_VALUE); private final AtomicReference beforeExecution = new AtomicReference<>(); @@ -162,11 +166,16 @@ public void reset() { } public void disableBlock() { - shouldBlock.set(false); + shouldBlock.release(Integer.MAX_VALUE); } public void enableBlock() { - shouldBlock.set(true); + try { + shouldBlock.acquire(Integer.MAX_VALUE); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } } public void setBeforeExecution(Runnable runnable) { @@ -197,6 +206,23 @@ public Map, Object>> pluginScripts() { ); } + public void logIfBlocked(String logMessage) { + if (shouldBlock.tryAcquire(1) == false) { + LogManager.getLogger(AbstractSearchCancellationTestCase.class).info(logMessage); + } else { + shouldBlock.release(1); + } + } + + public void waitForLock(int timeout, TimeUnit timeUnit) { + try { + assertTrue(shouldBlock.tryAcquire(timeout, timeUnit)); + shouldBlock.release(1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private Object searchBlockScript(Map params) { final Runnable runnable = beforeExecution.get(); if (runnable != null) { @@ -205,11 +231,7 @@ private Object searchBlockScript(Map params) { LeafStoredFieldsLookup fieldsLookup = (LeafStoredFieldsLookup) params.get("_fields"); LogManager.getLogger(AbstractSearchCancellationTestCase.class).info("Blocking on the document {}", fieldsLookup.get("_id")); hits.incrementAndGet(); - try { - assertBusy(() -> assertFalse(shouldBlock.get())); - } catch (Exception e) { - throw new RuntimeException(e); - } + waitForLock(10, TimeUnit.SECONDS); return true; } @@ -227,15 +249,9 @@ private Object blockScript(Map params) { if (runnable != null) { runnable.run(); } - if 
(shouldBlock.get()) { - LogManager.getLogger(AbstractSearchCancellationTestCase.class).info("Blocking in reduce"); - } + logIfBlocked("Blocking in reduce"); hits.incrementAndGet(); - try { - assertBusy(() -> assertFalse(shouldBlock.get())); - } catch (Exception e) { - throw new RuntimeException(e); - } + waitForLock(10, TimeUnit.SECONDS); return 42; } @@ -244,15 +260,9 @@ private Object mapBlockScript(Map params) { if (runnable != null) { runnable.run(); } - if (shouldBlock.get()) { - LogManager.getLogger(AbstractSearchCancellationTestCase.class).info("Blocking in map"); - } + logIfBlocked("Blocking in map"); hits.incrementAndGet(); - try { - assertBusy(() -> assertFalse(shouldBlock.get())); - } catch (Exception e) { - throw new RuntimeException(e); - } + waitForLock(10, TimeUnit.SECONDS); return 1; } diff --git a/test/framework/src/main/java/org/elasticsearch/test/fixture/AbstractHttpFixture.java b/test/framework/src/main/java/org/elasticsearch/test/fixture/AbstractHttpFixture.java index 87b8f5f89ffad..8e7fae85e57f5 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/fixture/AbstractHttpFixture.java +++ b/test/framework/src/main/java/org/elasticsearch/test/fixture/AbstractHttpFixture.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.PathUtils; import org.elasticsearch.core.SuppressForbidden; +import org.junit.rules.ExternalResource; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -40,7 +41,7 @@ * Base class for test fixtures that requires a {@link HttpServer} to work. 
*/ @SuppressForbidden(reason = "uses httpserver by design") -public abstract class AbstractHttpFixture { +public abstract class AbstractHttpFixture extends ExternalResource { protected static final Map TEXT_PLAIN_CONTENT_TYPE = contentType("text/plain; charset=utf-8"); protected static final Map JSON_CONTENT_TYPE = contentType("application/json; charset=utf-8"); @@ -51,8 +52,9 @@ public abstract class AbstractHttpFixture { private final AtomicLong requests = new AtomicLong(0); /** Current working directory of the fixture **/ - private final Path workingDirectory; - private final int port; + private Path workingDirectory; + private int port; + private HttpServer httpServer; protected AbstractHttpFixture(final String workingDir) { this(workingDir, 0); @@ -63,6 +65,8 @@ protected AbstractHttpFixture(final String workingDir, int port) { this.workingDirectory = PathUtils.get(Objects.requireNonNull(workingDir)); } + public AbstractHttpFixture() {} + /** * Opens a {@link HttpServer} and start listening on a provided or random port. 
*/ @@ -75,85 +79,100 @@ public final void listen() throws IOException, InterruptedException { */ public final void listen(InetAddress inetAddress, boolean exposePidAndPort) throws IOException, InterruptedException { final InetSocketAddress socketAddress = new InetSocketAddress(inetAddress, port); - final HttpServer httpServer = HttpServer.create(socketAddress, 0); + listenAndWait(socketAddress, exposePidAndPort); + } + public final void listenAndWait(InetSocketAddress socketAddress, boolean exposePidAndPort) throws IOException, InterruptedException { try { - if (exposePidAndPort) { - /// Writes the PID of the current Java process in a `pid` file located in the working directory - writeFile(workingDirectory, "pid", ManagementFactory.getRuntimeMXBean().getName().split("@")[0]); - - final String addressAndPort = addressToString(httpServer.getAddress()); - // Writes the address and port of the http server in a `ports` file located in the working directory - writeFile(workingDirectory, "ports", addressAndPort); - } - - httpServer.createContext("/", exchange -> { - try { - Response response; + listen(socketAddress, exposePidAndPort); + // Wait to be killed + Thread.sleep(Long.MAX_VALUE); + } finally { + stop(); + } + } - // Check if this is a request made by the AntFixture - final String userAgent = exchange.getRequestHeaders().getFirst("User-Agent"); - if (userAgent != null - && userAgent.startsWith("Apache Ant") - && "GET".equals(exchange.getRequestMethod()) - && "/".equals(exchange.getRequestURI().getPath())) { - response = new Response(200, TEXT_PLAIN_CONTENT_TYPE, "OK".getBytes(UTF_8)); + public final void listen(InetSocketAddress socketAddress, boolean exposePidAndPort) throws IOException, InterruptedException { + httpServer = HttpServer.create(socketAddress, 0); + if (exposePidAndPort) { + /// Writes the PID of the current Java process in a `pid` file located in the working directory + writeFile(workingDirectory, "pid", 
ManagementFactory.getRuntimeMXBean().getName().split("@")[0]); - } else { - try { - final long requestId = requests.getAndIncrement(); - final String method = exchange.getRequestMethod(); + final String addressAndPort = addressToString(httpServer.getAddress()); + // Writes the address and port of the http server in a `ports` file located in the working directory + writeFile(workingDirectory, "ports", addressAndPort); + } - final Map headers = new HashMap<>(); - for (Map.Entry> header : exchange.getRequestHeaders().entrySet()) { - headers.put(header.getKey(), exchange.getRequestHeaders().getFirst(header.getKey())); - } + httpServer.createContext("/", exchange -> { + try { + Response response; + + // Check if this is a request made by the AntFixture + final String userAgent = exchange.getRequestHeaders().getFirst("User-Agent"); + if (userAgent != null + && userAgent.startsWith("Apache Ant") + && "GET".equals(exchange.getRequestMethod()) + && "/".equals(exchange.getRequestURI().getPath())) { + response = new Response(200, TEXT_PLAIN_CONTENT_TYPE, "OK".getBytes(UTF_8)); + + } else { + try { + final long requestId = requests.getAndIncrement(); + final String method = exchange.getRequestMethod(); + + final Map headers = new HashMap<>(); + for (Map.Entry> header : exchange.getRequestHeaders().entrySet()) { + headers.put(header.getKey(), exchange.getRequestHeaders().getFirst(header.getKey())); + } - final ByteArrayOutputStream body = new ByteArrayOutputStream(); - try (InputStream requestBody = exchange.getRequestBody()) { - final byte[] buffer = new byte[1024]; - int i; - while ((i = requestBody.read(buffer, 0, buffer.length)) != -1) { - body.write(buffer, 0, i); - } - body.flush(); + final ByteArrayOutputStream body = new ByteArrayOutputStream(); + try (InputStream requestBody = exchange.getRequestBody()) { + final byte[] buffer = new byte[1024]; + int i; + while ((i = requestBody.read(buffer, 0, buffer.length)) != -1) { + body.write(buffer, 0, i); } + body.flush(); + } 
- final Request request = new Request(requestId, method, exchange.getRequestURI(), headers, body.toByteArray()); - response = handle(request); + final Request request = new Request(requestId, method, exchange.getRequestURI(), headers, body.toByteArray()); + response = handle(request); - } catch (Exception e) { - final String error = e.getMessage() != null ? e.getMessage() : "Exception when processing the request"; - response = new Response(500, singletonMap("Content-Type", "text/plain; charset=utf-8"), error.getBytes(UTF_8)); - } + } catch (Exception e) { + final String error = e.getMessage() != null ? e.getMessage() : "Exception when processing the request"; + response = new Response(500, singletonMap("Content-Type", "text/plain; charset=utf-8"), error.getBytes(UTF_8)); } + } - if (response == null) { - response = new Response(400, TEXT_PLAIN_CONTENT_TYPE, EMPTY_BYTE); - } + if (response == null) { + response = new Response(400, TEXT_PLAIN_CONTENT_TYPE, EMPTY_BYTE); + } - response.headers.forEach((k, v) -> exchange.getResponseHeaders().put(k, singletonList(v))); - if (response.body.length > 0) { - exchange.sendResponseHeaders(response.status, response.body.length); - exchange.getResponseBody().write(response.body); - } else { - exchange.sendResponseHeaders(response.status, -1); - } - } finally { - exchange.close(); + response.headers.forEach((k, v) -> exchange.getResponseHeaders().put(k, singletonList(v))); + if (response.body.length > 0) { + exchange.sendResponseHeaders(response.status, response.body.length); + exchange.getResponseBody().write(response.body); + } else { + exchange.sendResponseHeaders(response.status, -1); } - }); - httpServer.start(); + } finally { + exchange.close(); + } + }); + httpServer.start(); + } - // Wait to be killed - Thread.sleep(Long.MAX_VALUE); + protected abstract Response handle(Request request) throws IOException; - } finally { + protected void stop() { + if (httpServer != null) { httpServer.stop(0); } } - protected abstract 
Response handle(Request request) throws IOException; + public String getAddress() { + return "http://127.0.0.1:" + httpServer.getAddress().getPort(); + } @FunctionalInterface public interface RequestHandler { diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java index 2d0abaa5cf4ca..97f0b45fae462 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java @@ -21,6 +21,8 @@ import org.apache.http.ssl.SSLContexts; import org.apache.http.util.EntityUtils; import org.elasticsearch.Build; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.Version; import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction; import org.elasticsearch.action.admin.cluster.repositories.put.PutRepositoryRequest; @@ -109,6 +111,7 @@ import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -2107,7 +2110,53 @@ protected static IndexVersion minimumIndexVersion() throws IOException { return minVersion; } - private static Optional parseLegacyVersion(String version) { + /** + * Returns the minimum transport version among all nodes of the cluster + */ + protected static TransportVersion minimumTransportVersion() throws IOException { + Response response = client.performRequest(new Request("GET", "_nodes")); + ObjectPath objectPath = ObjectPath.createFromResponse(response); + Map nodesAsMap = objectPath.evaluate("nodes"); + + TransportVersion minTransportVersion = null; + for (String id : nodesAsMap.keySet()) { + + var transportVersion = getTransportVersionWithFallback( + objectPath.evaluate("nodes." 
+ id + ".version"), + objectPath.evaluate("nodes." + id + ".transport_version"), + () -> TransportVersions.MINIMUM_COMPATIBLE + ); + if (minTransportVersion == null || minTransportVersion.after(transportVersion)) { + minTransportVersion = transportVersion; + } + } + + assertNotNull(minTransportVersion); + return minTransportVersion; + } + + protected static TransportVersion getTransportVersionWithFallback( + String versionField, + Object transportVersionField, + Supplier fallbackSupplier + ) { + if (transportVersionField instanceof Number transportVersionId) { + return TransportVersion.fromId(transportVersionId.intValue()); + } else if (transportVersionField instanceof String transportVersionString) { + return TransportVersion.fromString(transportVersionString); + } else { // no transport_version field + // The response might be from a node <8.8.0, but about a node >=8.8.0 + // In that case the transport_version field won't exist. Use version, but only for <8.8.0: after that versions diverge. 
+ var version = parseLegacyVersion(versionField); + assert version.isPresent(); + if (version.get().before(Version.V_8_8_0)) { + return TransportVersion.fromId(version.get().id); + } + } + return fallbackSupplier.get(); + } + + protected static Optional parseLegacyVersion(String version) { var semanticVersionMatcher = SEMANTIC_VERSION_PATTERN.matcher(version); if (semanticVersionMatcher.matches()) { return Optional.of(Version.fromString(semanticVersionMatcher.group(1))); diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/RestTestLegacyFeatures.java b/test/framework/src/main/java/org/elasticsearch/test/rest/RestTestLegacyFeatures.java index 60653d32e1e38..bd19757bac438 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/RestTestLegacyFeatures.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/RestTestLegacyFeatures.java @@ -31,6 +31,10 @@ public class RestTestLegacyFeatures implements FeatureSpecification { "indices.delete_template_multiple_names_supported" ); + // QA - rolling upgrade tests + public static final NodeFeature SECURITY_UPDATE_API_KEY = new NodeFeature("security.api_key_update"); + public static final NodeFeature SECURITY_BULK_UPDATE_API_KEY = new NodeFeature("security.api_key_bulk_update"); + @Override public Map getHistoricalFeatures() { return Map.ofEntries( @@ -39,7 +43,9 @@ public Map getHistoricalFeatures() { entry(HIDDEN_INDICES_SUPPORTED, Version.V_7_7_0), entry(COMPONENT_TEMPLATE_SUPPORTED, Version.V_7_8_0), entry(DELETE_TEMPLATE_MULTIPLE_NAMES_SUPPORTED, Version.V_7_13_0), - entry(ML_STATE_RESET_FALLBACK_ON_DISABLED, Version.V_8_7_0) + entry(ML_STATE_RESET_FALLBACK_ON_DISABLED, Version.V_8_7_0), + entry(SECURITY_UPDATE_API_KEY, Version.V_8_4_0), + entry(SECURITY_BULK_UPDATE_API_KEY, Version.V_8_5_0) ); } } diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java 
b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index 847779c9066c4..695e96850e8e1 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -681,7 +681,7 @@ void populateAndRead( final List gaps = tracker.waitForRange( rangeToWrite, rangeToRead, - ActionListener.runBefore(listener, resource::close).delegateFailureAndWrap((l, success) -> { + ActionListener.runAfter(listener, resource::close).delegateFailureAndWrap((l, success) -> { var ioRef = io; assert regionOwners.get(ioRef) == this; final int start = Math.toIntExact(rangeToRead.start()); diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index eb1271edd3b06..4aa2e145228b8 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -73,6 +73,8 @@ exports org.elasticsearch.xpack.core.ilm.step.info; exports org.elasticsearch.xpack.core.ilm; exports org.elasticsearch.xpack.core.indexing; + exports org.elasticsearch.xpack.core.inference.action; + exports org.elasticsearch.xpack.core.inference.results; exports org.elasticsearch.xpack.core.logstash; exports org.elasticsearch.xpack.core.ml.action; exports org.elasticsearch.xpack.core.ml.annotations; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceModelAction.java similarity index 97% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceModelAction.java index 4062946935b2e..1324471f7c0ab 
100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceModelAction.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.action; +package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java similarity index 95% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java index a9b1fb32a7471..0343206994d2c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.action; +package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; @@ -38,9 +38,9 @@ public static class Request extends AcknowledgedRequest parsedResults) { + public static InferenceServiceResults transformToServiceResults(List parsedResults) { if (parsedResults.isEmpty()) { throw new ElasticsearchStatusException( "Failed to transform results to response format, expected a non-empty list, please remove and re-add the service", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java index 45b9474cebcdc..e6e4ea1001f68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.action; +package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java similarity index 94% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java index b5d6b8483138a..8f03a75c61c11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -39,8 +39,7 @@ * ] * } * - * This class represents the way that the {@link org.elasticsearch.xpack.inference.services.openai.OpenAiService} - * formatted the response for the embeddings type. This represents what was returned prior to the + * Legacy text embedding results represents what was returned prior to the * {@link org.elasticsearch.TransportVersions#INFERENCE_SERVICE_RESULTS_ADDED} version. 
* @deprecated use {@link TextEmbeddingResults} instead */ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java index 0e0299a5e12fd..20279e82d6c09 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 74f94e1aea17d..7a7ccab2b4daa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CoordinatedInferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CoordinatedInferenceAction.java new file mode 100644 index 0000000000000..8ff0c1179ea61 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CoordinatedInferenceAction.java @@ -0,0 +1,242 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class CoordinatedInferenceAction extends ActionType { + + public static final CoordinatedInferenceAction INSTANCE = new CoordinatedInferenceAction(); + public static final String NAME = "cluster:internal/xpack/ml/coordinatedinference"; + + public CoordinatedInferenceAction() { + super(NAME, InferModelAction.Response::new); + } + + public static class Request 
extends ActionRequest { + + public enum RequestModelType { + INFERENCE_SERVICE_MODEL, + ML_NODE_PYTORCH_MODEL, + BOOSTED_TREE_MODEL, + NLP_MODEL, // Either an inference service model or ml pytorch model but not a boosted tree model + UNKNOWN + }; + + public static Request forTextInput( + String modelId, + List inputs, + @Nullable InferenceConfigUpdate inferenceConfigUpdate, + @Nullable Boolean previouslyLicensed, + @Nullable TimeValue inferenceTimeout + ) { + return new Request( + modelId, + inputs, + null, + null, + inferenceConfigUpdate, + previouslyLicensed, + inferenceTimeout, + false, // not high priority + RequestModelType.NLP_MODEL + ); + } + + public static Request forMapInput( + String modelId, + List> objectsToInfer, + @Nullable InferenceConfigUpdate inferenceConfigUpdate, + @Nullable Boolean previouslyLicensed, + @Nullable TimeValue inferenceTimeout, + RequestModelType modelType + ) { + return new Request( + modelId, + null, + null, + objectsToInfer, + inferenceConfigUpdate, + previouslyLicensed, + inferenceTimeout, + false, // not high priority, + modelType + ); + } + + private final String modelId; + private final RequestModelType requestModelType; + // For inference services or cluster hosted NLP models + private final List inputs; + // _inference settings + private final Map taskSettings; + // In cluster model options + private final TimeValue inferenceTimeout; + private final Boolean previouslyLicensed; + private final InferenceConfigUpdate inferenceConfigUpdate; + private boolean highPriority; + private TrainedModelPrefixStrings.PrefixType prefixType = TrainedModelPrefixStrings.PrefixType.NONE; + // DFA models only + private final List> objectsToInfer; + + private Request( + String modelId, + @Nullable List inputs, + @Nullable Map taskSettings, + @Nullable List> objectsToInfer, + @Nullable InferenceConfigUpdate inferenceConfigUpdate, + @Nullable Boolean previouslyLicensed, + @Nullable TimeValue inferenceTimeout, + boolean highPriority, + 
RequestModelType requestModelType + ) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id"); + this.inputs = inputs; + this.taskSettings = taskSettings; + this.objectsToInfer = objectsToInfer; + this.inferenceConfigUpdate = inferenceConfigUpdate; + this.previouslyLicensed = previouslyLicensed; + this.inferenceTimeout = inferenceTimeout; + this.highPriority = highPriority; + this.requestModelType = requestModelType; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.requestModelType = in.readEnum(RequestModelType.class); + this.inputs = in.readOptionalStringCollectionAsList(); + this.taskSettings = in.readMap(); + this.objectsToInfer = in.readOptionalCollectionAsList(StreamInput::readMap); + this.inferenceConfigUpdate = in.readOptionalNamedWriteable(InferenceConfigUpdate.class); + this.previouslyLicensed = in.readOptionalBoolean(); + this.inferenceTimeout = in.readOptionalTimeValue(); + this.highPriority = in.readBoolean(); + } + + public String getModelId() { + return modelId; + } + + public List getInputs() { + return inputs; + } + + public Map getTaskSettings() { + return taskSettings; + } + + public List> getObjectsToInfer() { + return objectsToInfer; + } + + public InferenceConfigUpdate getInferenceConfigUpdate() { + return inferenceConfigUpdate; + } + + public Boolean getPreviouslyLicensed() { + return previouslyLicensed; + } + + public TimeValue getInferenceTimeout() { + return inferenceTimeout; + } + + public boolean getHighPriority() { + return highPriority; + } + + public void setHighPriority(boolean highPriority) { + this.highPriority = highPriority; + } + + public boolean hasInferenceConfig() { + return inferenceConfigUpdate != null; + } + + public boolean hasObjects() { + return objectsToInfer != null; + } + + public void setPrefixType(TrainedModelPrefixStrings.PrefixType prefixType) { + this.prefixType = prefixType; + } + + public TrainedModelPrefixStrings.PrefixType 
getPrefixType() { + return prefixType; + } + + public RequestModelType getRequestModelType() { + return requestModelType; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeEnum(requestModelType); + out.writeOptionalStringCollection(inputs); + out.writeGenericMap(taskSettings); + out.writeOptionalCollection(objectsToInfer, StreamOutput::writeGenericMap); + out.writeOptionalNamedWriteable(inferenceConfigUpdate); + out.writeOptionalBoolean(previouslyLicensed); + out.writeOptionalTimeValue(inferenceTimeout); + out.writeBoolean(highPriority); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(modelId, request.modelId) + && Objects.equals(requestModelType, request.requestModelType) + && Objects.equals(inputs, request.inputs) + && Objects.equals(taskSettings, request.taskSettings) + && Objects.equals(objectsToInfer, request.objectsToInfer) + && Objects.equals(inferenceConfigUpdate, request.inferenceConfigUpdate) + && Objects.equals(previouslyLicensed, request.previouslyLicensed) + && Objects.equals(inferenceTimeout, request.inferenceTimeout) + && Objects.equals(highPriority, request.highPriority); + } + + @Override + public int hashCode() { + return Objects.hash( + modelId, + requestModelType, + inputs, + taskSettings, + objectsToInfer, + inferenceConfigUpdate, + previouslyLicensed, + inferenceTimeout, + highPriority + ); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java index 0ba74df1f8d54..c098b13fd1deb 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/EmptyConfigUpdate.java @@ -19,6 +19,8 @@ public class EmptyConfigUpdate implements InferenceConfigUpdate { public static final String NAME = "empty"; + public static final EmptyConfigUpdate INSTANCE = new EmptyConfigUpdate(); + public static MlConfigVersion minimumSupportedVersion() { return MlConfigVersion.V_7_9_0; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java index 5ce5b0188771b..89dcf746d7927 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearnToRankConfig.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryRewriteContext; @@ -163,6 +164,11 @@ public int hashCode() { return Objects.hash(super.hashCode(), featureExtractorBuilders); } + @Override + public final String toString() { + return Strings.toString(this); + } + @Override public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.REGRESSION.equals(targetType); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CoordinatedInferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CoordinatedInferenceActionRequestTests.java new file mode 100644 index 0000000000000..bd8e0ad96f21a --- 
/dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CoordinatedInferenceActionRequestTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class CoordinatedInferenceActionRequestTests extends AbstractWireSerializingTestCase { + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + @Override + protected Writeable.Reader instanceReader() { + return CoordinatedInferenceAction.Request::new; + } + + @Override + protected CoordinatedInferenceAction.Request createTestInstance() { + return switch (randomIntBetween(0, 1)) { + case 0 -> { + var inferenceConfig = randomBoolean() ? null : InferModelActionRequestTests.randomInferenceConfigUpdate(); + var previouslyLicensed = randomBoolean() ? null : randomBoolean(); + var inferenceTimeout = randomBoolean() ? 
null : TimeValue.parseTimeValue(randomTimeValue(), null, "timeout"); + var highPriority = randomBoolean(); + + var request = CoordinatedInferenceAction.Request.forTextInput( + randomAlphaOfLength(6), + List.of(randomAlphaOfLength(6)), + inferenceConfig, + previouslyLicensed, + inferenceTimeout + ); + request.setHighPriority(highPriority); + yield request; + } + case 1 -> { + var inferenceConfig = randomBoolean() ? null : InferModelActionRequestTests.randomInferenceConfigUpdate(); + var previouslyLicensed = randomBoolean() ? null : randomBoolean(); + var inferenceTimeout = randomBoolean() ? null : TimeValue.parseTimeValue(randomTimeValue(), null, "timeout"); + var highPriority = randomBoolean(); + var modelType = randomFrom(CoordinatedInferenceAction.Request.RequestModelType.values()); + + var request = CoordinatedInferenceAction.Request.forMapInput( + randomAlphaOfLength(6), + Stream.generate(CoordinatedInferenceActionRequestTests::randomMap).limit(randomInt(5)).collect(Collectors.toList()), + inferenceConfig, + previouslyLicensed, + inferenceTimeout, + modelType + ); + request.setHighPriority(highPriority); + yield request; + } + default -> throw new UnsupportedOperationException(); + }; + } + + private static Map randomMap() { + return Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))); + } + + @Override + protected CoordinatedInferenceAction.Request mutateInstance(CoordinatedInferenceAction.Request instance) throws IOException { + return null; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index fcfc396313016..b33b64ccf69d7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -130,7 +130,7 @@ protected Request mutateInstance(Request instance) { return r; } - private static InferenceConfigUpdate randomInferenceConfigUpdate() { + public static InferenceConfigUpdate randomInferenceConfigUpdate() { return randomFrom( RegressionConfigUpdateTests.randomRegressionConfigUpdate(), ClassificationConfigUpdateTests.randomClassificationConfigUpdate(), diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/330_connector_update_pipeline.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/330_connector_update_pipeline.yml new file mode 100644 index 0000000000000..8d0bfe0232932 --- /dev/null +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/330_connector_update_pipeline.yml @@ -0,0 +1,64 @@ +setup: + - skip: + version: " - 8.11.99" + reason: Introduced in 8.12.0 + + - do: + connector.put: + connector_id: test-connector + body: + index_name: search-1-test + name: my-connector + language: pl + is_native: false + service_type: super-connector + +--- +"Update Connector Pipeline - Connector doesn't exist": + - do: + catch: "missing" + connector.update_pipeline: + connector_id: test-non-existent-connector + body: + pipeline: + extract_binary_content: true + name: test-pipeline + reduce_whitespace: true + run_ml_inference: false + +--- +"Update Connector Pipeline": + - do: + connector.update_pipeline: + connector_id: test-connector + body: + pipeline: + extract_binary_content: true + name: test-pipeline + reduce_whitespace: true + run_ml_inference: false + + - match: { result: updated } + + - do: + connector.get: + connector_id: test-connector + + - match: { pipeline.extract_binary_content: true } + - match: { pipeline.name: test-pipeline } + - match: { pipeline.reduce_whitespace: true } + - match: { 
pipeline.run_ml_inference: false } + +--- +"Update Connector Pipeline - Required fields are missing": + - do: + catch: "bad_request" + connector.update_pipeline: + connector_id: test-connector + body: + pipeline: + extract_binary_content: true + name: test-pipeline + run_ml_inference: false + + diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/331_connector_update_scheduling.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/331_connector_update_scheduling.yml index 21d588f538fc5..e8e3fa0e87068 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/331_connector_update_scheduling.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/331_connector_update_scheduling.yml @@ -45,7 +45,7 @@ setup: - match: { scheduling.incremental.interval: "3 0 0 * * ?" } --- -"Update Connector Scheduling - 404 status code returned when connector doesn't exist": +"Update Connector Scheduling - Connector doesn't exist": - do: catch: "missing" connector.update_scheduling: @@ -63,7 +63,7 @@ setup: interval: 3 0 0 * * ? --- -"Update Connector Scheduling - 400 status code returned when required fields are missing": +"Update Connector Scheduling - Required fields are missing": - do: catch: "bad_request" connector.update_scheduling: @@ -75,7 +75,7 @@ setup: interval: 3 0 0 * * ? 
--- -"Update Connector Scheduling - 400 status code returned with wrong CRON expression": +"Update Connector Scheduling - Wrong CRON expression": - do: catch: "bad_request" connector.update_scheduling: diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml new file mode 100644 index 0000000000000..c5634365db3ec --- /dev/null +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml @@ -0,0 +1,278 @@ +setup: + - skip: + version: " - 8.11.99" + reason: Introduced in 8.12.0 + + - do: + connector.put: + connector_id: test-connector + body: + index_name: search-1-test + name: my-connector + language: pl + is_native: false + service_type: super-connector +--- +"Update Connector Filtering": + - do: + connector.update_filtering: + connector_id: test-connector + body: + filtering: + - active: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: RULE-ACTIVE-0 + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + domain: DEFAULT + draft: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: RULE-DRAFT-0 + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + - active: + advanced_snippet: + created_at: "2021-05-25T12:30:00.000Z" + updated_at: "2021-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2021-05-25T12:30:00.000Z" + field: _ + id: RULE-ACTIVE-1 + order: 0 + 
policy: include + rule: regex + updated_at: "2021-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + domain: TEST + draft: + advanced_snippet: + created_at: "2021-05-25T12:30:00.000Z" + updated_at: "2021-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2021-05-25T12:30:00.000Z" + field: _ + id: RULE-DRAFT-1 + order: 0 + policy: exclude + rule: regex + updated_at: "2021-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + + - match: { result: updated } + + - do: + connector.get: + connector_id: test-connector + + - match: { filtering.0.domain: DEFAULT } + - match: { filtering.0.active.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } + - match: { filtering.0.active.rules.0.id: "RULE-ACTIVE-0" } + - match: { filtering.0.draft.rules.0.id: "RULE-DRAFT-0" } + + - match: { filtering.1.domain: TEST } + - match: { filtering.1.active.advanced_snippet.created_at: "2021-05-25T12:30:00.000Z" } + - match: { filtering.1.active.rules.0.id: "RULE-ACTIVE-1" } + - match: { filtering.1.draft.rules.0.id: "RULE-DRAFT-1" } + +--- +"Update Connector Filtering - Connector doesn't exist": + - do: + catch: "missing" + connector.update_filtering: + connector_id: test-non-existent-connector + body: + filtering: + - active: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: RULE-ACTIVE-0 + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + domain: DEFAULT + draft: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: RULE-DRAFT-0 + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + - active: 
+ advanced_snippet: + created_at: "2021-05-25T12:30:00.000Z" + updated_at: "2021-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2021-05-25T12:30:00.000Z" + field: _ + id: RULE-ACTIVE-1 + order: 0 + policy: include + rule: regex + updated_at: "2021-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + domain: TEST + draft: + advanced_snippet: + created_at: "2021-05-25T12:30:00.000Z" + updated_at: "2021-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2021-05-25T12:30:00.000Z" + field: _ + id: RULE-DRAFT-1 + order: 0 + policy: exclude + rule: regex + updated_at: "2021-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + +--- +"Update Connector Filtering - Required fields are missing": + - do: + catch: "bad_request" + connector.update_filtering: + connector_id: test-connector + body: + filtering: + - domain: some_domain + + - match: + status: 400 + +--- +"Update Connector Filtering - Wrong datetime expression": + - do: + catch: "bad_request" + connector.update_filtering: + connector_id: test-connector + body: + filtering: + - active: + advanced_snippet: + created_at: "this-is-not-a-datetime-!!!!" 
+ updated_at: "2023-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: RULE-ACTIVE-0 + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + domain: DEFAULT + draft: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: RULE-DRAFT-0 + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + - active: + advanced_snippet: + created_at: "2021-05-25T12:30:00.000Z" + updated_at: "2021-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2021-05-25T12:30:00.000Z" + field: _ + id: RULE-ACTIVE-1 + order: 0 + policy: include + rule: regex + updated_at: "2021-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid + domain: TEST + draft: + advanced_snippet: + created_at: "2021-05-25T12:30:00.000Z" + updated_at: "2021-05-25T12:30:00.000Z" + value: {} + rules: + - created_at: "2021-05-25T12:30:00.000Z" + field: _ + id: RULE-DRAFT-1 + order: 0 + policy: exclude + rule: regex + updated_at: "2021-05-25T12:30:00.000Z" + value: ".*" + validation: + errors: [] + state: valid diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/420_connector_sync_job_check_in.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/420_connector_sync_job_check_in.yml new file mode 100644 index 0000000000000..9ef37f4a9fe60 --- /dev/null +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/420_connector_sync_job_check_in.yml @@ -0,0 +1,36 @@ +setup: + - skip: + version: " - 8.11.99" + reason: Introduced in 8.12.0 + - do: + connector.put: + connector_id: test-connector + body: + index_name: search-test + 
name: my-connector + language: de + is_native: false + service_type: super-connector + +--- +"Check in a Connector Sync Job": + - do: + connector_sync_job.post: + body: + id: test-connector + job_type: full + trigger_method: on_demand + - set: { id: sync-job-id-to-check-in } + - do: + connector_sync_job.check_in: + connector_sync_job_id: $sync-job-id-to-check-in + + - match: { acknowledged: true } + + +--- +"Check in a Connector Sync Job - Connector Sync Job does not exist": + - do: + connector_sync_job.check_in: + connector_sync_job_id: test-nonexistent-connector-sync-job-id + catch: missing diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/430_connector_sync_job_cancel.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/430_connector_sync_job_cancel.yml new file mode 100644 index 0000000000000..e9c612cbf9f27 --- /dev/null +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/430_connector_sync_job_cancel.yml @@ -0,0 +1,36 @@ +setup: + - skip: + version: " - 8.11.99" + reason: Introduced in 8.12.0 + - do: + connector.put: + connector_id: test-connector + body: + index_name: search-test + name: my-connector + language: de + is_native: false + service_type: super-connector + +--- +"Cancel a Connector Sync Job": + - do: + connector_sync_job.post: + body: + id: test-connector + job_type: full + trigger_method: on_demand + - set: { id: sync-job-id-to-cancel } + - do: + connector_sync_job.cancel: + connector_sync_job_id: $sync-job-id-to-cancel + + - match: { acknowledged: true } + + +--- +"Cancel a Connector Sync Job - Connector Sync Job does not exist": + - do: + connector_sync_job.check_in: + connector_sync_job_id: test-nonexistent-connector-sync-job-id + catch: missing diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/EnterpriseSearch.java 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/EnterpriseSearch.java index 26ac6dc9b939d..3402c3a8b9d7b 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/EnterpriseSearch.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/EnterpriseSearch.java @@ -50,17 +50,29 @@ import org.elasticsearch.xpack.application.connector.action.RestGetConnectorAction; import org.elasticsearch.xpack.application.connector.action.RestListConnectorAction; import org.elasticsearch.xpack.application.connector.action.RestPutConnectorAction; +import org.elasticsearch.xpack.application.connector.action.RestUpdateConnectorFilteringAction; +import org.elasticsearch.xpack.application.connector.action.RestUpdateConnectorPipelineAction; import org.elasticsearch.xpack.application.connector.action.RestUpdateConnectorSchedulingAction; import org.elasticsearch.xpack.application.connector.action.TransportDeleteConnectorAction; import org.elasticsearch.xpack.application.connector.action.TransportGetConnectorAction; import org.elasticsearch.xpack.application.connector.action.TransportListConnectorAction; import org.elasticsearch.xpack.application.connector.action.TransportPutConnectorAction; +import org.elasticsearch.xpack.application.connector.action.TransportUpdateConnectorFilteringAction; +import org.elasticsearch.xpack.application.connector.action.TransportUpdateConnectorPipelineAction; import org.elasticsearch.xpack.application.connector.action.TransportUpdateConnectorSchedulingAction; +import org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction; +import org.elasticsearch.xpack.application.connector.action.UpdateConnectorPipelineAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction; +import org.elasticsearch.xpack.application.connector.syncjob.action.CancelConnectorSyncJobAction; +import 
org.elasticsearch.xpack.application.connector.syncjob.action.CheckInConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.DeleteConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.PostConnectorSyncJobAction; +import org.elasticsearch.xpack.application.connector.syncjob.action.RestCancelConnectorSyncJobAction; +import org.elasticsearch.xpack.application.connector.syncjob.action.RestCheckInConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.RestDeleteConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.RestPostConnectorSyncJobAction; +import org.elasticsearch.xpack.application.connector.syncjob.action.TransportCancelConnectorSyncJobAction; +import org.elasticsearch.xpack.application.connector.syncjob.action.TransportCheckInConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.TransportDeleteConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.TransportPostConnectorSyncJobAction; import org.elasticsearch.xpack.application.rules.QueryRulesConfig; @@ -178,16 +190,20 @@ protected XPackLicenseState getLicenseState() { if (ConnectorAPIFeature.isEnabled()) { actionHandlers.addAll( List.of( - // Connector API + // Connectors API new ActionHandler<>(DeleteConnectorAction.INSTANCE, TransportDeleteConnectorAction.class), new ActionHandler<>(GetConnectorAction.INSTANCE, TransportGetConnectorAction.class), new ActionHandler<>(ListConnectorAction.INSTANCE, TransportListConnectorAction.class), new ActionHandler<>(PutConnectorAction.INSTANCE, TransportPutConnectorAction.class), + new ActionHandler<>(UpdateConnectorFilteringAction.INSTANCE, TransportUpdateConnectorFilteringAction.class), + new ActionHandler<>(UpdateConnectorPipelineAction.INSTANCE, TransportUpdateConnectorPipelineAction.class), new 
ActionHandler<>(UpdateConnectorSchedulingAction.INSTANCE, TransportUpdateConnectorSchedulingAction.class), // SyncJob API new ActionHandler<>(PostConnectorSyncJobAction.INSTANCE, TransportPostConnectorSyncJobAction.class), - new ActionHandler<>(DeleteConnectorSyncJobAction.INSTANCE, TransportDeleteConnectorSyncJobAction.class) + new ActionHandler<>(DeleteConnectorSyncJobAction.INSTANCE, TransportDeleteConnectorSyncJobAction.class), + new ActionHandler<>(CheckInConnectorSyncJobAction.INSTANCE, TransportCheckInConnectorSyncJobAction.class), + new ActionHandler<>(CancelConnectorSyncJobAction.INSTANCE, TransportCancelConnectorSyncJobAction.class) ) ); } @@ -238,16 +254,20 @@ public List getRestHandlers( if (ConnectorAPIFeature.isEnabled()) { restHandlers.addAll( List.of( - // Connector API + // Connectors API new RestDeleteConnectorAction(), new RestGetConnectorAction(), new RestListConnectorAction(), new RestPutConnectorAction(), + new RestUpdateConnectorFilteringAction(), + new RestUpdateConnectorPipelineAction(), new RestUpdateConnectorSchedulingAction(), // SyncJob API new RestPostConnectorSyncJobAction(), - new RestDeleteConnectorSyncJobAction() + new RestDeleteConnectorSyncJobAction(), + new RestCancelConnectorSyncJobAction(), + new RestCheckInConnectorSyncJobAction() ) ); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java index d632a28d3f858..749e8c2e9dd87 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java @@ -31,6 +31,8 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; +import 
org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction; +import org.elasticsearch.xpack.application.connector.action.UpdateConnectorPipelineAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction; import java.util.Arrays; @@ -169,6 +171,66 @@ public void onFailure(Exception e) { } } + /** + * Updates the {@link ConnectorFiltering} property of a {@link Connector}. + * + * @param request Request for updating connector filtering property. + * @param listener Listener to respond to a successful response or an error. + */ + public void updateConnectorFiltering(UpdateConnectorFilteringAction.Request request, ActionListener listener) { + try { + String connectorId = request.getConnectorId(); + final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_INDEX_NAME, connectorId).doc( + new IndexRequest(CONNECTOR_INDEX_NAME).opType(DocWriteRequest.OpType.INDEX) + .id(connectorId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(request.toXContent(jsonBuilder(), ToXContent.EMPTY_PARAMS)) + ); + clientWithOrigin.update( + updateRequest, + new DelegatingIndexNotFoundActionListener<>(connectorId, listener, (l, updateResponse) -> { + if (updateResponse.getResult() == UpdateResponse.Result.NOT_FOUND) { + l.onFailure(new ResourceNotFoundException(connectorId)); + return; + } + l.onResponse(updateResponse); + }) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Updates the {@link ConnectorIngestPipeline} property of a {@link Connector}. + * + * @param request Request for updating connector ingest pipeline property. + * @param listener Listener to respond to a successful response or an error. 
+ */ + public void updateConnectorPipeline(UpdateConnectorPipelineAction.Request request, ActionListener listener) { + try { + String connectorId = request.getConnectorId(); + final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_INDEX_NAME, connectorId).doc( + new IndexRequest(CONNECTOR_INDEX_NAME).opType(DocWriteRequest.OpType.INDEX) + .id(connectorId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(request.toXContent(jsonBuilder(), ToXContent.EMPTY_PARAMS)) + ); + clientWithOrigin.update( + updateRequest, + new DelegatingIndexNotFoundActionListener<>(connectorId, listener, (l, updateResponse) -> { + if (updateResponse.getResult() == UpdateResponse.Result.NOT_FOUND) { + l.onFailure(new ResourceNotFoundException(connectorId)); + return; + } + l.onResponse(updateResponse); + }) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + /** * Updates the {@link ConnectorScheduling} property of a {@link Connector}. * diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java new file mode 100644 index 0000000000000..63ae3e81fe563 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.application.EnterpriseSearch; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.PUT; + +public class RestUpdateConnectorFilteringAction extends BaseRestHandler { + + @Override + public String getName() { + return "connector_update_filtering_action"; + } + + @Override + public List routes() { + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_filtering")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { + UpdateConnectorFilteringAction.Request request = UpdateConnectorFilteringAction.Request.fromXContentBytes( + restRequest.param("connector_id"), + restRequest.content(), + restRequest.getXContentType() + ); + return channel -> client.execute( + UpdateConnectorFilteringAction.INSTANCE, + request, + new RestToXContentListener<>(channel, UpdateConnectorFilteringAction.Response::status, r -> null) + ); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java new file mode 100644 index 0000000000000..ba83bd42dac11 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.application.EnterpriseSearch; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.PUT; + +public class RestUpdateConnectorPipelineAction extends BaseRestHandler { + + @Override + public String getName() { + return "connector_update_pipeline_action"; + } + + @Override + public List routes() { + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_pipeline")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { + UpdateConnectorPipelineAction.Request request = UpdateConnectorPipelineAction.Request.fromXContentBytes( + restRequest.param("connector_id"), + restRequest.content(), + restRequest.getXContentType() + ); + return channel -> client.execute( + UpdateConnectorPipelineAction.INSTANCE, + request, + new RestToXContentListener<>(channel, UpdateConnectorPipelineAction.Response::status, r -> null) + ); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java new file mode 100644 index 0000000000000..e871eb4bb79e5 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.ConnectorIndexService; + +public class TransportUpdateConnectorFilteringAction extends HandledTransportAction< + UpdateConnectorFilteringAction.Request, + UpdateConnectorFilteringAction.Response> { + + protected final ConnectorIndexService connectorIndexService; + + @Inject + public TransportUpdateConnectorFilteringAction( + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + Client client + ) { + super( + UpdateConnectorFilteringAction.NAME, + transportService, + actionFilters, + UpdateConnectorFilteringAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.connectorIndexService = new ConnectorIndexService(client); + } + + @Override + protected void doExecute( + Task task, + UpdateConnectorFilteringAction.Request request, + ActionListener listener + ) { + connectorIndexService.updateConnectorFiltering( + request, + listener.map(r -> new UpdateConnectorFilteringAction.Response(r.getResult())) + ); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorPipelineAction.java 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorPipelineAction.java new file mode 100644 index 0000000000000..c54d3db1215bc --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorPipelineAction.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.ConnectorIndexService; + +public class TransportUpdateConnectorPipelineAction extends HandledTransportAction< + UpdateConnectorPipelineAction.Request, + UpdateConnectorPipelineAction.Response> { + + protected final ConnectorIndexService connectorIndexService; + + @Inject + public TransportUpdateConnectorPipelineAction( + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + Client client + ) { + super( + UpdateConnectorPipelineAction.NAME, + transportService, + actionFilters, + UpdateConnectorPipelineAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.connectorIndexService = new ConnectorIndexService(client); + } + + @Override + protected void doExecute( + Task task, + 
UpdateConnectorPipelineAction.Request request, + ActionListener listener + ) { + connectorIndexService.updateConnectorPipeline( + request, + listener.map(r -> new UpdateConnectorPipelineAction.Response(r.getResult())) + ); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java new file mode 100644 index 0000000000000..68c644cb9d9db --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import 
org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.application.connector.Connector; +import org.elasticsearch.xpack.application.connector.ConnectorFiltering; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +public class UpdateConnectorFilteringAction extends ActionType { + + public static final UpdateConnectorFilteringAction INSTANCE = new UpdateConnectorFilteringAction(); + public static final String NAME = "cluster:admin/xpack/connector/update_filtering"; + + public UpdateConnectorFilteringAction() { + super(NAME, UpdateConnectorFilteringAction.Response::new); + } + + public static class Request extends ActionRequest implements ToXContentObject { + + private final String connectorId; + private final List filtering; + + public Request(String connectorId, List filtering) { + this.connectorId = connectorId; + this.filtering = filtering; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.connectorId = in.readString(); + this.filtering = in.readOptionalCollectionAsList(ConnectorFiltering::new); + } + + public String getConnectorId() { + return connectorId; + } + + public List getFiltering() { + return filtering; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + + if (Strings.isNullOrEmpty(connectorId)) { + validationException = addValidationError("[connector_id] cannot be null or empty.", validationException); + } + + if (filtering == null) { + validationException = addValidationError("[filtering] cannot be null.", validationException); + } + + return validationException; + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "connector_update_filtering_request", 
+ false, + ((args, connectorId) -> new UpdateConnectorFilteringAction.Request(connectorId, (List) args[0])) + ); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> ConnectorFiltering.fromXContent(p), Connector.FILTERING_FIELD); + } + + public static UpdateConnectorFilteringAction.Request fromXContentBytes( + String connectorId, + BytesReference source, + XContentType xContentType + ) { + try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { + return UpdateConnectorFilteringAction.Request.fromXContent(parser, connectorId); + } catch (IOException e) { + throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); + } + } + + public static UpdateConnectorFilteringAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { + return PARSER.parse(parser, connectorId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + builder.field(Connector.FILTERING_FIELD.getPreferredName(), filtering); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(connectorId); + out.writeOptionalCollection(filtering); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(connectorId, request.connectorId) && Objects.equals(filtering, request.filtering); + } + + @Override + public int hashCode() { + return Objects.hash(connectorId, filtering); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + final DocWriteResponse.Result result; + + public Response(StreamInput in) throws IOException { + super(in); + result = DocWriteResponse.Result.readFrom(in); + } + + public 
Response(DocWriteResponse.Result result) { + this.result = result; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.result.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("result", this.result.getLowercase()); + builder.endObject(); + return builder; + } + + public RestStatus status() { + return switch (result) { + case NOT_FOUND -> RestStatus.NOT_FOUND; + default -> RestStatus.OK; + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response that = (Response) o; + return Objects.equals(result, that.result); + } + + @Override + public int hashCode() { + return Objects.hash(result); + } + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java new file mode 100644 index 0000000000000..68babb2d4b517 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.application.connector.Connector; +import org.elasticsearch.xpack.application.connector.ConnectorIngestPipeline; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +public class UpdateConnectorPipelineAction extends ActionType { + + public static final UpdateConnectorPipelineAction INSTANCE = new UpdateConnectorPipelineAction(); + public static final String NAME = "cluster:admin/xpack/connector/update_pipeline"; + + public UpdateConnectorPipelineAction() { + super(NAME, UpdateConnectorPipelineAction.Response::new); + } + + public static class Request extends ActionRequest implements ToXContentObject { + + private final String connectorId; + private final ConnectorIngestPipeline pipeline; + + public Request(String connectorId, ConnectorIngestPipeline pipeline) 
{ + this.connectorId = connectorId; + this.pipeline = pipeline; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.connectorId = in.readString(); + this.pipeline = in.readOptionalWriteable(ConnectorIngestPipeline::new); + } + + public String getConnectorId() { + return connectorId; + } + + public ConnectorIngestPipeline getPipeline() { + return pipeline; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + + if (Strings.isNullOrEmpty(connectorId)) { + validationException = addValidationError("[connector_id] cannot be null or empty.", validationException); + } + + if (Objects.isNull(pipeline)) { + validationException = addValidationError("[pipeline] cannot be null.", validationException); + } + + return validationException; + } + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "connector_update_pipeline_request", + false, + ((args, connectorId) -> new UpdateConnectorPipelineAction.Request(connectorId, (ConnectorIngestPipeline) args[0])) + ); + + static { + PARSER.declareObject(constructorArg(), (p, c) -> ConnectorIngestPipeline.fromXContent(p), Connector.PIPELINE_FIELD); + } + + public static UpdateConnectorPipelineAction.Request fromXContentBytes( + String connectorId, + BytesReference source, + XContentType xContentType + ) { + try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { + return UpdateConnectorPipelineAction.Request.fromXContent(parser, connectorId); + } catch (IOException e) { + throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); + } + } + + public static UpdateConnectorPipelineAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { + return PARSER.parse(parser, connectorId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params 
params) throws IOException { + builder.startObject(); + { + builder.field(Connector.PIPELINE_FIELD.getPreferredName(), pipeline); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(connectorId); + out.writeOptionalWriteable(pipeline); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(connectorId, request.connectorId) && Objects.equals(pipeline, request.pipeline); + } + + @Override + public int hashCode() { + return Objects.hash(connectorId, pipeline); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + final DocWriteResponse.Result result; + + public Response(StreamInput in) throws IOException { + super(in); + result = DocWriteResponse.Result.readFrom(in); + } + + public Response(DocWriteResponse.Result result) { + this.result = result; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.result.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("result", this.result.getLowercase()); + builder.endObject(); + return builder; + } + + public RestStatus status() { + return switch (result) { + case NOT_FOUND -> RestStatus.NOT_FOUND; + default -> RestStatus.OK; + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response that = (Response) o; + return Objects.equals(result, that.result); + } + + @Override + public int hashCode() { + return Objects.hash(result); + } + + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorSchedulingAction.java 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorSchedulingAction.java index eb0e265c44f28..9867830c5d211 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorSchedulingAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorSchedulingAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -30,6 +31,7 @@ import java.io.IOException; import java.util.Objects; +import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; public class UpdateConnectorSchedulingAction extends ActionType { @@ -67,7 +69,17 @@ public ConnectorScheduling getScheduling() { @Override public ActionRequestValidationException validate() { - return null; + ActionRequestValidationException validationException = null; + + if (Strings.isNullOrEmpty(connectorId)) { + validationException = addValidationError("[connector_id] cannot be null or empty.", validationException); + } + + if (Objects.isNull(scheduling)) { + validationException = addValidationError("[scheduling] cannot be null.", validationException); + } + + return validationException; } private static final ConstructingObjectParser PARSER = diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobConstants.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobConstants.java new file mode 100644 index 
0000000000000..cf44ab4e733c8 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobConstants.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob; + +import static org.elasticsearch.xpack.application.connector.syncjob.action.DeleteConnectorSyncJobAction.Request.CONNECTOR_SYNC_JOB_ID_FIELD; + +public class ConnectorSyncJobConstants { + + public static final String EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE = + "[connector_sync_job_id] of the connector sync job cannot be null or empty."; + public static final String CONNECTOR_SYNC_JOB_ID_PARAM = CONNECTOR_SYNC_JOB_ID_FIELD.getPreferredName(); + +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java index 5deb63fd60669..ab593fe99fcee 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java @@ -19,15 +19,19 @@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.common.UUIDs; import 
org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorFiltering; import org.elasticsearch.xpack.application.connector.ConnectorIndexService; import org.elasticsearch.xpack.application.connector.ConnectorIngestPipeline; +import org.elasticsearch.xpack.application.connector.ConnectorSyncStatus; import org.elasticsearch.xpack.application.connector.ConnectorTemplateRegistry; import org.elasticsearch.xpack.application.connector.syncjob.action.PostConnectorSyncJobAction; @@ -114,6 +118,102 @@ public void createConnectorSyncJob( } } + /** + * Deletes the {@link ConnectorSyncJob} in the underlying index. + * + * @param connectorSyncJobId The id of the connector sync job object. + * @param listener The action listener to invoke on response/failure. + */ + public void deleteConnectorSyncJob(String connectorSyncJobId, ActionListener listener) { + final DeleteRequest deleteRequest = new DeleteRequest(CONNECTOR_SYNC_JOB_INDEX_NAME).id(connectorSyncJobId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + try { + clientWithOrigin.delete( + deleteRequest, + new DelegatingIndexNotFoundOrDocumentMissingActionListener<>(connectorSyncJobId, listener, (l, deleteResponse) -> { + if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { + l.onFailure(new ResourceNotFoundException(connectorSyncJobId)); + return; + } + l.onResponse(deleteResponse); + }) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Checks in the {@link ConnectorSyncJob} in the underlying index. + * In this context "checking in" means to update the "last_seen" timestamp to the time when the method was called. + * + * @param connectorSyncJobId The id of the connector sync job object. 
+ * @param listener The action listener to invoke on response/failure. + */ + public void checkInConnectorSyncJob(String connectorSyncJobId, ActionListener listener) { + Instant newLastSeen = Instant.now(); + + final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_SYNC_JOB_INDEX_NAME, connectorSyncJobId).setRefreshPolicy( + WriteRequest.RefreshPolicy.IMMEDIATE + ).doc(Map.of(ConnectorSyncJob.LAST_SEEN_FIELD.getPreferredName(), newLastSeen)); + + try { + clientWithOrigin.update( + updateRequest, + new DelegatingIndexNotFoundOrDocumentMissingActionListener<>(connectorSyncJobId, listener, (l, updateResponse) -> { + if (updateResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { + l.onFailure(new ResourceNotFoundException(connectorSyncJobId)); + return; + } + l.onResponse(updateResponse); + }) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Cancels the {@link ConnectorSyncJob} in the underlying index. + * Canceling means to set the {@link ConnectorSyncStatus} to "canceling" and not "canceled" as this is an async operation. + * It also updates 'cancelation_requested_at' to the time, when the method was called. + * + * @param connectorSyncJobId The id of the connector sync job object. + * @param listener The action listener to invoke on response/failure. 
+ */ + public void cancelConnectorSyncJob(String connectorSyncJobId, ActionListener listener) { + Instant cancellationRequestedAt = Instant.now(); + + final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_SYNC_JOB_INDEX_NAME, connectorSyncJobId).setRefreshPolicy( + WriteRequest.RefreshPolicy.IMMEDIATE + ) + .doc( + Map.of( + ConnectorSyncJob.STATUS_FIELD.getPreferredName(), + ConnectorSyncStatus.CANCELING, + ConnectorSyncJob.CANCELATION_REQUESTED_AT_FIELD.getPreferredName(), + cancellationRequestedAt + ) + ); + + try { + clientWithOrigin.update( + updateRequest, + new DelegatingIndexNotFoundOrDocumentMissingActionListener<>(connectorSyncJobId, listener, (l, updateResponse) -> { + if (updateResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { + l.onFailure(new ResourceNotFoundException(connectorSyncJobId)); + return; + } + l.onResponse(updateResponse); + }) + ); + } catch (Exception e) { + listener.onFailure(e); + } + + } + private String generateId() { /* Workaround: only needed for generating an id upfront, autoGenerateId() has a side effect generating a timestamp, * which would raise an error on the response layer later ("autoGeneratedTimestamp should not be set externally"). @@ -165,41 +265,19 @@ public void onFailure(Exception e) { } /** - * Deletes the {@link ConnectorSyncJob} in the underlying index. - * - * @param connectorSyncJobId The id of the connector sync job object. - * @param listener The action listener to invoke on response/failure. 
- */ - public void deleteConnectorSyncJob(String connectorSyncJobId, ActionListener listener) { - final DeleteRequest deleteRequest = new DeleteRequest(CONNECTOR_SYNC_JOB_INDEX_NAME).id(connectorSyncJobId) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - try { - clientWithOrigin.delete( - deleteRequest, - new DelegatingIndexNotFoundActionListener<>(connectorSyncJobId, listener, (l, deleteResponse) -> { - if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { - l.onFailure(new ResourceNotFoundException(connectorSyncJobId)); - return; - } - l.onResponse(deleteResponse); - }) - ); - } catch (Exception e) { - listener.onFailure(e); - } - } - - /** - * Listeners that checks failures for IndexNotFoundException, and transforms them in ResourceNotFoundException, - * invoking onFailure on the delegate listener + * Listeners that checks failures for IndexNotFoundException and DocumentMissingException, + * and transforms them in ResourceNotFoundException, invoking onFailure on the delegate listener. 
*/ - static class DelegatingIndexNotFoundActionListener extends DelegatingActionListener { + static class DelegatingIndexNotFoundOrDocumentMissingActionListener extends DelegatingActionListener { private final BiConsumer, T> bc; private final String connectorSyncJobId; - DelegatingIndexNotFoundActionListener(String connectorSyncJobId, ActionListener delegate, BiConsumer, T> bc) { + DelegatingIndexNotFoundOrDocumentMissingActionListener( + String connectorSyncJobId, + ActionListener delegate, + BiConsumer, T> bc + ) { super(delegate); this.bc = bc; this.connectorSyncJobId = connectorSyncJobId; @@ -213,7 +291,7 @@ public void onResponse(T t) { @Override public void onFailure(Exception e) { Throwable cause = ExceptionsHelper.unwrapCause(e); - if (cause instanceof IndexNotFoundException) { + if (cause instanceof IndexNotFoundException || cause instanceof DocumentMissingException) { delegate.onFailure(new ResourceNotFoundException("connector sync job [" + connectorSyncJobId + "] not found")); return; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobAction.java new file mode 100644 index 0000000000000..7179bbb3a62f2 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobAction.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE; + +public class CancelConnectorSyncJobAction extends ActionType { + + public static final CancelConnectorSyncJobAction INSTANCE = new CancelConnectorSyncJobAction(); + public static final String NAME = "cluster:admin/xpack/connector/sync_job/cancel"; + + private CancelConnectorSyncJobAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends ActionRequest implements ToXContentObject { + public static final ParseField CONNECTOR_SYNC_JOB_ID_FIELD = new ParseField("connector_sync_job_id"); + + private final String connectorSyncJobId; + + public Request(StreamInput in) throws IOException { + super(in); + this.connectorSyncJobId = in.readString(); + } + + public Request(String connectorSyncJobId) { + this.connectorSyncJobId = connectorSyncJobId; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException 
validationException = null; + + if (Strings.isNullOrEmpty(connectorSyncJobId)) { + validationException = addValidationError(EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE, validationException); + } + + return validationException; + } + + public String getConnectorSyncJobId() { + return connectorSyncJobId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(connectorSyncJobId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(connectorSyncJobId, request.connectorSyncJobId); + } + + @Override + public int hashCode() { + return Objects.hash(connectorSyncJobId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CONNECTOR_SYNC_JOB_ID_FIELD.getPreferredName(), connectorSyncJobId); + builder.endObject(); + return builder; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "cancel_connector_sync_job_request", + false, + (args) -> new Request((String) args[0]) + ); + + static { + PARSER.declareString(constructorArg(), CONNECTOR_SYNC_JOB_ID_FIELD); + } + + public static CancelConnectorSyncJobAction.Request parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } + +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobAction.java new file mode 100644 index 0000000000000..3e5e1578cd54d --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobAction.java @@ -0,0 +1,111 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +public class CheckInConnectorSyncJobAction extends ActionType { + + public static final CheckInConnectorSyncJobAction INSTANCE = new CheckInConnectorSyncJobAction(); + public static final String NAME = "cluster:admin/xpack/connector/sync_job/check_in"; + + private CheckInConnectorSyncJobAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends ActionRequest implements ToXContentObject { + public static final ParseField CONNECTOR_SYNC_JOB_ID_FIELD = new ParseField("connector_sync_job_id"); + private final String connectorSyncJobId; + + public Request(StreamInput in) throws IOException { + super(in); + this.connectorSyncJobId = in.readString(); + } + + public 
Request(String connectorSyncJobId) { + this.connectorSyncJobId = connectorSyncJobId; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + + if (Strings.isNullOrEmpty(connectorSyncJobId)) { + validationException = addValidationError( + ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE, + validationException + ); + } + + return validationException; + } + + public String getConnectorSyncJobId() { + return connectorSyncJobId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(connectorSyncJobId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(connectorSyncJobId, request.connectorSyncJobId); + } + + @Override + public int hashCode() { + return Objects.hash(connectorSyncJobId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CONNECTOR_SYNC_JOB_ID_FIELD.getPreferredName(), connectorSyncJobId); + builder.endObject(); + return builder; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "check_in_connector_sync_job_request", + false, + (args) -> new Request((String) args[0]) + ); + + static { + PARSER.declareString(constructorArg(), CONNECTOR_SYNC_JOB_ID_FIELD); + } + + public static CheckInConnectorSyncJobAction.Request parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobAction.java index 
147f8784a8ec7..05cd6cce90fdd 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants; import java.io.IOException; import java.util.Objects; @@ -36,8 +37,6 @@ private DeleteConnectorSyncJobAction() { } public static class Request extends ActionRequest implements ToXContentObject { - public static final String EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE = - "[connector_sync_job_id] of the connector sync job cannot be null or empty."; public static final ParseField CONNECTOR_SYNC_JOB_ID_FIELD = new ParseField("connector_sync_job_id"); private final String connectorSyncJobId; @@ -56,7 +55,10 @@ public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; if (Strings.isNullOrEmpty(connectorSyncJobId)) { - validationException = addValidationError(EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE, validationException); + validationException = addValidationError( + ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE, + validationException + ); } return validationException; diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestCancelConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestCancelConnectorSyncJobAction.java new file mode 100644 index 0000000000000..82d679c6f0ad0 --- /dev/null +++ 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestCancelConnectorSyncJobAction.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.application.EnterpriseSearch; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.application.connector.syncjob.action.DeleteConnectorSyncJobAction.Request.CONNECTOR_SYNC_JOB_ID_FIELD; + +public class RestCancelConnectorSyncJobAction extends BaseRestHandler { + + private static final String CONNECTOR_SYNC_JOB_ID_PARAM = CONNECTOR_SYNC_JOB_ID_FIELD.getPreferredName(); + + @Override + public String getName() { + return "connector_sync_job_cancel_action"; + } + + @Override + public List routes() { + return List.of( + new Route( + RestRequest.Method.PUT, + "/" + EnterpriseSearch.CONNECTOR_SYNC_JOB_API_ENDPOINT + "/{" + CONNECTOR_SYNC_JOB_ID_PARAM + "}/_cancel" + ) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + CancelConnectorSyncJobAction.Request request = new CancelConnectorSyncJobAction.Request( + restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM) + ); + return restChannel -> client.execute(CancelConnectorSyncJobAction.INSTANCE, request, new RestToXContentListener<>(restChannel)); + } +} diff --git 
a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestCheckInConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestCheckInConnectorSyncJobAction.java new file mode 100644 index 0000000000000..86f97f4c5fdb4 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestCheckInConnectorSyncJobAction.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.application.EnterpriseSearch; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants.CONNECTOR_SYNC_JOB_ID_PARAM; + +public class RestCheckInConnectorSyncJobAction extends BaseRestHandler { + + @Override + public String getName() { + return "connector_sync_job_check_in_action"; + } + + @Override + public List routes() { + return List.of( + new Route( + RestRequest.Method.PUT, + "/" + EnterpriseSearch.CONNECTOR_SYNC_JOB_API_ENDPOINT + "/{" + CONNECTOR_SYNC_JOB_ID_PARAM + "}/_check_in" + ) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + CheckInConnectorSyncJobAction.Request request = new CheckInConnectorSyncJobAction.Request( + restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM) + ); + + return restChannel -> 
client.execute(CheckInConnectorSyncJobAction.INSTANCE, request, new RestToXContentListener<>(restChannel)); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestDeleteConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestDeleteConnectorSyncJobAction.java index 283675f89d1db..c1f352a341cc3 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestDeleteConnectorSyncJobAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestDeleteConnectorSyncJobAction.java @@ -16,12 +16,10 @@ import java.io.IOException; import java.util.List; -import static org.elasticsearch.xpack.application.connector.syncjob.action.DeleteConnectorSyncJobAction.Request.CONNECTOR_SYNC_JOB_ID_FIELD; +import static org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants.CONNECTOR_SYNC_JOB_ID_PARAM; public class RestDeleteConnectorSyncJobAction extends BaseRestHandler { - private static final String CONNECTOR_SYNC_JOB_ID_PARAM = CONNECTOR_SYNC_JOB_ID_FIELD.getPreferredName(); - @Override public String getName() { return "connector_sync_job_delete_action"; diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCancelConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCancelConnectorSyncJobAction.java new file mode 100644 index 0000000000000..ac61dcdf08a61 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCancelConnectorSyncJobAction.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobIndexService; + +public class TransportCancelConnectorSyncJobAction extends HandledTransportAction< + CancelConnectorSyncJobAction.Request, + AcknowledgedResponse> { + + protected ConnectorSyncJobIndexService connectorSyncJobIndexService; + + @Inject + public TransportCancelConnectorSyncJobAction( + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + Client client + ) { + super( + CancelConnectorSyncJobAction.NAME, + transportService, + actionFilters, + CancelConnectorSyncJobAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.connectorSyncJobIndexService = new ConnectorSyncJobIndexService(client); + } + + @Override + protected void doExecute(Task task, CancelConnectorSyncJobAction.Request request, ActionListener listener) { + connectorSyncJobIndexService.cancelConnectorSyncJob(request.getConnectorSyncJobId(), listener.map(r -> AcknowledgedResponse.TRUE)); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCheckInConnectorSyncJobAction.java 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCheckInConnectorSyncJobAction.java new file mode 100644 index 0000000000000..ebaadc80f4c27 --- /dev/null +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCheckInConnectorSyncJobAction.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobIndexService; + +public class TransportCheckInConnectorSyncJobAction extends HandledTransportAction< + CheckInConnectorSyncJobAction.Request, + AcknowledgedResponse> { + + protected final ConnectorSyncJobIndexService connectorSyncJobIndexService; + + @Inject + public TransportCheckInConnectorSyncJobAction( + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + Client client + ) { + super( + CheckInConnectorSyncJobAction.NAME, + transportService, + actionFilters, + CheckInConnectorSyncJobAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.connectorSyncJobIndexService = new 
ConnectorSyncJobIndexService(client); + } + + @Override + protected void doExecute(Task task, CheckInConnectorSyncJobAction.Request request, ActionListener listener) { + connectorSyncJobIndexService.checkInConnectorSyncJob(request.getConnectorSyncJobId(), listener.map(r -> AcknowledgedResponse.TRUE)); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java index 5d0d539262f10..5f32f27b1ec64 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java @@ -14,6 +14,8 @@ import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction; +import org.elasticsearch.xpack.application.connector.action.UpdateConnectorPipelineAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction; import org.junit.Before; @@ -22,6 +24,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.hamcrest.CoreMatchers.anyOf; import static org.hamcrest.CoreMatchers.equalTo; @@ -61,6 +65,49 @@ public void testDeleteConnector() throws Exception { expectThrows(ResourceNotFoundException.class, () -> awaitDeleteConnector(connectorIdToDelete)); } + public void testUpdateConnectorPipeline() throws Exception { + Connector connector = ConnectorTestUtils.getRandomConnector(); + DocWriteResponse resp = awaitPutConnector(connector); + 
assertThat(resp.status(), anyOf(equalTo(RestStatus.CREATED), equalTo(RestStatus.OK))); + + ConnectorIngestPipeline updatedPipeline = new ConnectorIngestPipeline.Builder().setName("test-pipeline") + .setExtractBinaryContent(false) + .setReduceWhitespace(true) + .setRunMlInference(false) + .build(); + + UpdateConnectorPipelineAction.Request updatePipelineRequest = new UpdateConnectorPipelineAction.Request( + connector.getConnectorId(), + updatedPipeline + ); + + DocWriteResponse updateResponse = awaitUpdateConnectorPipeline(updatePipelineRequest); + assertThat(updateResponse.status(), equalTo(RestStatus.OK)); + Connector indexedConnector = awaitGetConnector(connector.getConnectorId()); + assertThat(updatedPipeline, equalTo(indexedConnector.getPipeline())); + } + + public void testUpdateConnectorFiltering() throws Exception { + Connector connector = ConnectorTestUtils.getRandomConnector(); + + DocWriteResponse resp = awaitPutConnector(connector); + assertThat(resp.status(), anyOf(equalTo(RestStatus.CREATED), equalTo(RestStatus.OK))); + + List filteringList = IntStream.range(0, 10) + .mapToObj((i) -> ConnectorTestUtils.getRandomConnectorFiltering()) + .collect(Collectors.toList()); + + UpdateConnectorFilteringAction.Request updateFilteringRequest = new UpdateConnectorFilteringAction.Request( + connector.getConnectorId(), + filteringList + ); + + DocWriteResponse updateResponse = awaitUpdateConnectorFiltering(updateFilteringRequest); + assertThat(updateResponse.status(), equalTo(RestStatus.OK)); + Connector indexedConnector = awaitGetConnector(connector.getConnectorId()); + assertThat(filteringList, equalTo(indexedConnector.getFiltering())); + } + public void testUpdateConnectorScheduling() throws Exception { Connector connector = ConnectorTestUtils.getRandomConnector(); DocWriteResponse resp = awaitPutConnector(connector); @@ -180,6 +227,56 @@ public void onFailure(Exception e) { return resp.get(); } + private UpdateResponse 
awaitUpdateConnectorFiltering(UpdateConnectorFilteringAction.Request updateFiltering) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + final AtomicReference resp = new AtomicReference<>(null); + final AtomicReference exc = new AtomicReference<>(null); + connectorIndexService.updateConnectorFiltering(updateFiltering, new ActionListener<>() { + @Override + public void onResponse(UpdateResponse indexResponse) { + resp.set(indexResponse); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + exc.set(e); + latch.countDown(); + } + }); + assertTrue("Timeout waiting for update filtering request", latch.await(REQUEST_TIMEOUT_SECONDS, TimeUnit.SECONDS)); + if (exc.get() != null) { + throw exc.get(); + } + assertNotNull("Received null response from update filtering request", resp.get()); + return resp.get(); + } + + private UpdateResponse awaitUpdateConnectorPipeline(UpdateConnectorPipelineAction.Request updatePipeline) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + final AtomicReference resp = new AtomicReference<>(null); + final AtomicReference exc = new AtomicReference<>(null); + connectorIndexService.updateConnectorPipeline(updatePipeline, new ActionListener<>() { + @Override + public void onResponse(UpdateResponse indexResponse) { + resp.set(indexResponse); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + exc.set(e); + latch.countDown(); + } + }); + assertTrue("Timeout waiting for update pipeline request", latch.await(REQUEST_TIMEOUT_SECONDS, TimeUnit.SECONDS)); + if (exc.get() != null) { + throw exc.get(); + } + assertNotNull("Received null response from update pipeline request", resp.get()); + return resp.get(); + } + private UpdateResponse awaitUpdateConnectorScheduling(UpdateConnectorSchedulingAction.Request updatedScheduling) throws Exception { CountDownLatch latch = new CountDownLatch(1); final AtomicReference resp = new AtomicReference<>(null); diff --git 
a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java new file mode 100644 index 0000000000000..1d433d58be6ad --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.application.connector.ConnectorTestUtils; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; + +import java.io.IOException; +import java.util.List; + +public class UpdateConnectorFilteringActionRequestBWCSerializingTests extends AbstractBWCSerializationTestCase< + UpdateConnectorFilteringAction.Request> { + + private String connectorId; + + @Override + protected Writeable.Reader instanceReader() { + return UpdateConnectorFilteringAction.Request::new; + } + + @Override + protected UpdateConnectorFilteringAction.Request createTestInstance() { + this.connectorId = randomUUID(); + return new UpdateConnectorFilteringAction.Request( + connectorId, + List.of(ConnectorTestUtils.getRandomConnectorFiltering(), ConnectorTestUtils.getRandomConnectorFiltering()) + ); + } + + @Override + protected UpdateConnectorFilteringAction.Request mutateInstance(UpdateConnectorFilteringAction.Request instance) throws 
IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected UpdateConnectorFilteringAction.Request doParseInstance(XContentParser parser) throws IOException { + return UpdateConnectorFilteringAction.Request.fromXContent(parser, this.connectorId); + } + + @Override + protected UpdateConnectorFilteringAction.Request mutateInstanceForVersion( + UpdateConnectorFilteringAction.Request instance, + TransportVersion version + ) { + return instance; + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionResponseBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionResponseBWCSerializingTests.java new file mode 100644 index 0000000000000..0f33eeac8dfb5 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionResponseBWCSerializingTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class UpdateConnectorFilteringActionResponseBWCSerializingTests extends AbstractBWCWireSerializationTestCase< + UpdateConnectorFilteringAction.Response> { + + @Override + protected Writeable.Reader instanceReader() { + return UpdateConnectorFilteringAction.Response::new; + } + + @Override + protected UpdateConnectorFilteringAction.Response createTestInstance() { + return new UpdateConnectorFilteringAction.Response(randomFrom(DocWriteResponse.Result.values())); + } + + @Override + protected UpdateConnectorFilteringAction.Response mutateInstance(UpdateConnectorFilteringAction.Response instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected UpdateConnectorFilteringAction.Response mutateInstanceForVersion( + UpdateConnectorFilteringAction.Response instance, + TransportVersion version + ) { + return instance; + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineActionRequestBWCSerializingTests.java new file mode 100644 index 0000000000000..14df1b704f995 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineActionRequestBWCSerializingTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.application.connector.ConnectorTestUtils; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; + +import java.io.IOException; + +public class UpdateConnectorPipelineActionRequestBWCSerializingTests extends AbstractBWCSerializationTestCase< + UpdateConnectorPipelineAction.Request> { + + private String connectorId; + + @Override + protected Writeable.Reader instanceReader() { + return UpdateConnectorPipelineAction.Request::new; + } + + @Override + protected UpdateConnectorPipelineAction.Request createTestInstance() { + this.connectorId = randomUUID(); + return new UpdateConnectorPipelineAction.Request(connectorId, ConnectorTestUtils.getRandomConnectorIngestPipeline()); + } + + @Override + protected UpdateConnectorPipelineAction.Request mutateInstance(UpdateConnectorPipelineAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected UpdateConnectorPipelineAction.Request doParseInstance(XContentParser parser) throws IOException { + return UpdateConnectorPipelineAction.Request.fromXContent(parser, this.connectorId); + } + + @Override + protected UpdateConnectorPipelineAction.Request mutateInstanceForVersion( + UpdateConnectorPipelineAction.Request instance, + TransportVersion version + ) { + return instance; + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineActionResponseBWCSerializingTests.java 
b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineActionResponseBWCSerializingTests.java new file mode 100644 index 0000000000000..065dafcaf00a4 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineActionResponseBWCSerializingTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class UpdateConnectorPipelineActionResponseBWCSerializingTests extends AbstractBWCWireSerializationTestCase< + UpdateConnectorPipelineAction.Response> { + @Override + protected Writeable.Reader instanceReader() { + return UpdateConnectorPipelineAction.Response::new; + } + + @Override + protected UpdateConnectorPipelineAction.Response createTestInstance() { + return new UpdateConnectorPipelineAction.Response(randomFrom(DocWriteResponse.Result.values())); + } + + @Override + protected UpdateConnectorPipelineAction.Response mutateInstance(UpdateConnectorPipelineAction.Response instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected UpdateConnectorPipelineAction.Response mutateInstanceForVersion( + UpdateConnectorPipelineAction.Response instance, + TransportVersion version + ) { + return instance; + } +} diff --git 
a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java index 9ac1f4935c6cc..cadc8b761cbe3 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java @@ -17,8 +17,10 @@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorIndexService; @@ -28,22 +30,32 @@ import org.junit.Before; import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; public class ConnectorSyncJobIndexServiceTests 
extends ESSingleNodeTestCase { private static final String NON_EXISTING_CONNECTOR_ID = "non-existing-connector-id"; + private static final String NON_EXISTING_SYNC_JOB_ID = "non-existing-sync-job-id"; + private static final String LAST_SEEN_FIELD_NAME = ConnectorSyncJob.LAST_SEEN_FIELD.getPreferredName(); private static final int TIMEOUT_SECONDS = 10; + private static final int ONE_SECOND_IN_MILLIS = 1000; private ConnectorSyncJobIndexService connectorSyncJobIndexService; private Connector connector; @@ -169,7 +181,161 @@ public void testDeleteConnectorSyncJob() throws Exception { } public void testDeleteConnectorSyncJob_WithMissingSyncJobId_ExpectException() { - expectThrows(ResourceNotFoundException.class, () -> awaitDeleteConnectorSyncJob("non-existing-sync-job-id")); + expectThrows(ResourceNotFoundException.class, () -> awaitDeleteConnectorSyncJob(NON_EXISTING_SYNC_JOB_ID)); + } + + public void testCheckInConnectorSyncJob() throws Exception { + PostConnectorSyncJobAction.Request syncJobRequest = ConnectorSyncJobTestUtils.getRandomPostConnectorSyncJobActionRequest( + connector.getConnectorId() + ); + PostConnectorSyncJobAction.Response response = awaitPutConnectorSyncJob(syncJobRequest); + String syncJobId = response.getId(); + + Map syncJobSourceBeforeUpdate = getConnectorSyncJobSourceById(syncJobId); + Instant lastSeenBeforeUpdate = Instant.parse((String) syncJobSourceBeforeUpdate.get(LAST_SEEN_FIELD_NAME)); + + safeSleep(ONE_SECOND_IN_MILLIS); + + UpdateResponse updateResponse = awaitCheckInConnectorSyncJob(syncJobId); + Map syncJobSourceAfterUpdate = getConnectorSyncJobSourceById(syncJobId); + Instant lastSeenAfterUpdate = Instant.parse((String) syncJobSourceAfterUpdate.get(LAST_SEEN_FIELD_NAME)); + long secondsBetweenLastSeenBeforeAndAfterUpdate = ChronoUnit.SECONDS.between(lastSeenBeforeUpdate, lastSeenAfterUpdate); + + assertThat("Wrong sync job was updated", syncJobId, equalTo(updateResponse.getId())); + assertThat(updateResponse.status(), 
equalTo(RestStatus.OK)); + assertTrue( + "[" + LAST_SEEN_FIELD_NAME + "] after the check in is not after [" + LAST_SEEN_FIELD_NAME + "] before the check in", + lastSeenAfterUpdate.isAfter(lastSeenBeforeUpdate) + ); + assertThat( + "there must be at least one second between [" + + LAST_SEEN_FIELD_NAME + + "] after the check in and [" + + LAST_SEEN_FIELD_NAME + + "] before the check in", + secondsBetweenLastSeenBeforeAndAfterUpdate, + greaterThanOrEqualTo(1L) + ); + assertFieldsExceptLastSeenDidNotUpdate(syncJobSourceBeforeUpdate, syncJobSourceAfterUpdate); + } + + public void testCheckInConnectorSyncJob_WithMissingSyncJobId_ExpectException() { + expectThrows(ResourceNotFoundException.class, () -> awaitCheckInConnectorSyncJob(NON_EXISTING_SYNC_JOB_ID)); + } + + public void testCancelConnectorSyncJob() throws Exception { + PostConnectorSyncJobAction.Request syncJobRequest = ConnectorSyncJobTestUtils.getRandomPostConnectorSyncJobActionRequest( + connector.getConnectorId() + ); + PostConnectorSyncJobAction.Response response = awaitPutConnectorSyncJob(syncJobRequest); + String syncJobId = response.getId(); + Map syncJobSourceBeforeUpdate = getConnectorSyncJobSourceById(syncJobId); + ConnectorSyncStatus syncStatusBeforeUpdate = ConnectorSyncStatus.fromString( + (String) syncJobSourceBeforeUpdate.get(ConnectorSyncJob.STATUS_FIELD.getPreferredName()) + ); + Object cancellationRequestedAtBeforeUpdate = syncJobSourceBeforeUpdate.get( + ConnectorSyncJob.CANCELATION_REQUESTED_AT_FIELD.getPreferredName() + ); + + assertThat(syncJobId, notNullValue()); + assertThat(cancellationRequestedAtBeforeUpdate, nullValue()); + assertThat(syncStatusBeforeUpdate, not(equalTo(ConnectorSyncStatus.CANCELING))); + + UpdateResponse updateResponse = awaitCancelConnectorSyncJob(syncJobId); + + Map syncJobSourceAfterUpdate = getConnectorSyncJobSourceById(syncJobId); + ConnectorSyncStatus syncStatusAfterUpdate = ConnectorSyncStatus.fromString( + (String) 
syncJobSourceAfterUpdate.get(ConnectorSyncJob.STATUS_FIELD.getPreferredName()) + ); + Instant cancellationRequestedAtAfterUpdate = Instant.parse( + (String) syncJobSourceAfterUpdate.get(ConnectorSyncJob.CANCELATION_REQUESTED_AT_FIELD.getPreferredName()) + ); + + assertThat(updateResponse.status(), equalTo(RestStatus.OK)); + assertThat(cancellationRequestedAtAfterUpdate, notNullValue()); + assertThat(syncStatusAfterUpdate, equalTo(ConnectorSyncStatus.CANCELING)); + assertFieldsExceptSyncStatusAndCancellationRequestedAtDidNotUpdate(syncJobSourceBeforeUpdate, syncJobSourceAfterUpdate); + } + + public void testCancelConnectorSyncJob_WithMissingSyncJobId_ExpectException() { + expectThrows(ResourceNotFoundException.class, () -> awaitCancelConnectorSyncJob(NON_EXISTING_SYNC_JOB_ID)); + } + + private static void assertFieldsExceptSyncStatusAndCancellationRequestedAtDidNotUpdate( + Map syncJobSourceBeforeUpdate, + Map syncJobSourceAfterUpdate + ) { + assertFieldsDidNotUpdateExceptFieldList( + syncJobSourceBeforeUpdate, + syncJobSourceAfterUpdate, + List.of(ConnectorSyncJob.STATUS_FIELD, ConnectorSyncJob.CANCELATION_REQUESTED_AT_FIELD) + ); + } + + private static void assertFieldsExceptLastSeenDidNotUpdate( + Map syncJobSourceBeforeUpdate, + Map syncJobSourceAfterUpdate + ) { + assertFieldsDidNotUpdateExceptFieldList( + syncJobSourceBeforeUpdate, + syncJobSourceAfterUpdate, + List.of(ConnectorSyncJob.LAST_SEEN_FIELD) + ); + } + + private static void assertFieldsDidNotUpdateExceptFieldList( + Map syncJobSourceBeforeUpdate, + Map syncJobSourceAfterUpdate, + List fieldsWhichShouldUpdate + ) { + Set fieldsNamesWhichShouldUpdate = fieldsWhichShouldUpdate.stream() + .map(ParseField::getPreferredName) + .collect(Collectors.toSet()); + + for (Map.Entry field : syncJobSourceBeforeUpdate.entrySet()) { + String fieldName = field.getKey(); + boolean isFieldWhichShouldNotUpdate = fieldsNamesWhichShouldUpdate.contains(fieldName) == false; + + if (isFieldWhichShouldNotUpdate) { + Object 
fieldValueBeforeUpdate = field.getValue(); + Object fieldValueAfterUpdate = syncJobSourceAfterUpdate.get(fieldName); + + assertThat( + "Every field except [" + + String.join(",", fieldsNamesWhichShouldUpdate) + + "] should stay the same. [" + + fieldName + + "] did change.", + fieldValueBeforeUpdate, + equalTo(fieldValueAfterUpdate) + ); + } + } + } + + private UpdateResponse awaitCancelConnectorSyncJob(String syncJobId) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + final AtomicReference resp = new AtomicReference<>(null); + final AtomicReference exc = new AtomicReference<>(null); + connectorSyncJobIndexService.cancelConnectorSyncJob(syncJobId, new ActionListener<>() { + @Override + public void onResponse(UpdateResponse updateResponse) { + resp.set(updateResponse); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + exc.set(e); + latch.countDown(); + } + }); + assertTrue("Timeout waiting for cancel request", latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); + if (exc.get() != null) { + throw exc.get(); + } + assertNotNull("Received null response from cancel request", resp.get()); + return resp.get(); } private Map getConnectorSyncJobSourceById(String syncJobId) throws ExecutionException, InterruptedException, @@ -180,6 +346,31 @@ private Map getConnectorSyncJobSourceById(String syncJobId) thro return getResponseActionFuture.get(TIMEOUT_SECONDS, TimeUnit.SECONDS).getSource(); } + private UpdateResponse awaitCheckInConnectorSyncJob(String connectorSyncJobId) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + final AtomicReference resp = new AtomicReference<>(null); + final AtomicReference exc = new AtomicReference<>(null); + connectorSyncJobIndexService.checkInConnectorSyncJob(connectorSyncJobId, new ActionListener<>() { + @Override + public void onResponse(UpdateResponse updateResponse) { + resp.set(updateResponse); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + 
exc.set(e); + latch.countDown(); + } + }); + assertTrue("Timeout waiting for check in request", latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); + if (exc.get() != null) { + throw exc.get(); + } + assertNotNull("Received null response from check in request", resp.get()); + return resp.get(); + } + private void awaitPutConnectorSyncJobExpectingException( PostConnectorSyncJobAction.Request syncJobRequest, ActionListener listener diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java index 099173735edd2..4fa1b9122284d 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java @@ -9,6 +9,8 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.application.connector.ConnectorTestUtils; +import org.elasticsearch.xpack.application.connector.syncjob.action.CancelConnectorSyncJobAction; +import org.elasticsearch.xpack.application.connector.syncjob.action.CheckInConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.DeleteConnectorSyncJobAction; import org.elasticsearch.xpack.application.connector.syncjob.action.PostConnectorSyncJobAction; @@ -90,4 +92,12 @@ public static PostConnectorSyncJobAction.Request getRandomPostConnectorSyncJobAc public static PostConnectorSyncJobAction.Response getRandomPostConnectorSyncJobActionResponse() { return new PostConnectorSyncJobAction.Response(randomAlphaOfLength(10)); } + + public static CancelConnectorSyncJobAction.Request getRandomCancelConnectorSyncJobActionRequest() { + return new CancelConnectorSyncJobAction.Request(randomAlphaOfLength(10)); + } + + public static 
CheckInConnectorSyncJobAction.Request getRandomCheckInConnectorSyncJobActionRequest() { + return new CheckInConnectorSyncJobAction.Request(randomAlphaOfLength(10)); + } } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobActionRequestBWCSerializingTests.java new file mode 100644 index 0000000000000..81f59a130ac70 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobActionRequestBWCSerializingTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; + +import java.io.IOException; + +public class CancelConnectorSyncJobActionRequestBWCSerializingTests extends AbstractBWCSerializationTestCase< + CancelConnectorSyncJobAction.Request> { + @Override + protected Writeable.Reader instanceReader() { + return CancelConnectorSyncJobAction.Request::new; + } + + @Override + protected CancelConnectorSyncJobAction.Request createTestInstance() { + return ConnectorSyncJobTestUtils.getRandomCancelConnectorSyncJobActionRequest(); + } + + @Override + protected CancelConnectorSyncJobAction.Request mutateInstance(CancelConnectorSyncJobAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected CancelConnectorSyncJobAction.Request doParseInstance(XContentParser parser) throws IOException { + return CancelConnectorSyncJobAction.Request.parse(parser); + } + + @Override + protected CancelConnectorSyncJobAction.Request mutateInstanceForVersion( + CancelConnectorSyncJobAction.Request instance, + TransportVersion version + ) { + return new CancelConnectorSyncJobAction.Request(instance.getConnectorSyncJobId()); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobActionTests.java new file mode 100644 index 0000000000000..0dd8d452254dc --- /dev/null +++ 
b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CancelConnectorSyncJobActionTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; + +import static org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class CancelConnectorSyncJobActionTests extends ESTestCase { + + public void testValidate_WhenConnectorSyncJobIdIsPresent_ExpectNoValidationError() { + CancelConnectorSyncJobAction.Request request = ConnectorSyncJobTestUtils.getRandomCancelConnectorSyncJobActionRequest(); + ActionRequestValidationException exception = request.validate(); + + assertThat(exception, nullValue()); + } + + public void testValidate_WhenConnectorSyncJobIdIsEmpty_ExpectValidationError() { + CancelConnectorSyncJobAction.Request requestWithMissingConnectorId = new CancelConnectorSyncJobAction.Request(""); + ActionRequestValidationException exception = requestWithMissingConnectorId.validate(); + + assertThat(exception, notNullValue()); + assertThat(exception.getMessage(), containsString(EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE)); + } + +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobActionRequestBWCSerializingTests.java 
b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobActionRequestBWCSerializingTests.java new file mode 100644 index 0000000000000..63f874b32f37c --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobActionRequestBWCSerializingTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; + +import java.io.IOException; + +public class CheckInConnectorSyncJobActionRequestBWCSerializingTests extends AbstractBWCSerializationTestCase< + CheckInConnectorSyncJobAction.Request> { + @Override + protected Writeable.Reader instanceReader() { + return CheckInConnectorSyncJobAction.Request::new; + } + + @Override + protected CheckInConnectorSyncJobAction.Request createTestInstance() { + return ConnectorSyncJobTestUtils.getRandomCheckInConnectorSyncJobActionRequest(); + } + + @Override + protected CheckInConnectorSyncJobAction.Request mutateInstance(CheckInConnectorSyncJobAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected CheckInConnectorSyncJobAction.Request doParseInstance(XContentParser parser) throws IOException { + return CheckInConnectorSyncJobAction.Request.parse(parser); + } + + @Override + protected 
CheckInConnectorSyncJobAction.Request mutateInstanceForVersion( + CheckInConnectorSyncJobAction.Request instance, + TransportVersion version + ) { + return new CheckInConnectorSyncJobAction.Request(instance.getConnectorSyncJobId()); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobActionTests.java new file mode 100644 index 0000000000000..fe5046e42f828 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/CheckInConnectorSyncJobActionTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class CheckInConnectorSyncJobActionTests extends ESTestCase { + + public void testValidate_WhenConnectorSyncJobIdIsPresent_ExpectNoValidationError() { + CheckInConnectorSyncJobAction.Request request = ConnectorSyncJobTestUtils.getRandomCheckInConnectorSyncJobActionRequest(); + ActionRequestValidationException exception = request.validate(); + + assertThat(exception, nullValue()); + } + + public void testValidate_WhenConnectorSyncJobIdIsEmpty_ExpectValidationError() { + CheckInConnectorSyncJobAction.Request requestWithMissingConnectorSyncJobId = new CheckInConnectorSyncJobAction.Request(""); + ActionRequestValidationException exception = requestWithMissingConnectorSyncJobId.validate(); + + assertThat(exception, notNullValue()); + assertThat(exception.getMessage(), containsString(ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE)); + } + +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobActionTests.java index ee79db86152c6..00dff3e83211b 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobActionTests.java +++ 
b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/DeleteConnectorSyncJobActionTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; import static org.hamcrest.Matchers.containsString; @@ -28,7 +29,7 @@ public void testValidate_WhenConnectorSyncJobIdIsEmpty_ExpectValidationError() { ActionRequestValidationException exception = requestWithMissingConnectorId.validate(); assertThat(exception, notNullValue()); - assertThat(exception.getMessage(), containsString(DeleteConnectorSyncJobAction.Request.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE)); + assertThat(exception.getMessage(), containsString(ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE)); } } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCancelConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCancelConnectorSyncJobActionTests.java new file mode 100644 index 0000000000000..81c56e3345e28 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCancelConnectorSyncJobActionTests.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; +import org.junit.Before; + +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.mock; + +public class TransportCancelConnectorSyncJobActionTests extends ESSingleNodeTestCase { + + private static final Long TIMEOUT_SECONDS = 10L; + private final ThreadPool threadPool = new TestThreadPool(getClass().getName()); + private TransportCancelConnectorSyncJobAction action; + + @Before + public void setup() { + ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + action = new TransportCancelConnectorSyncJobAction(transportService, clusterService, mock(ActionFilters.class), client()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + ThreadPool.terminate(threadPool, TIMEOUT_SECONDS, TimeUnit.SECONDS); + } + + public void testCancelConnectorSyncJob_ExpectNoWarnings() throws InterruptedException { + CancelConnectorSyncJobAction.Request request = ConnectorSyncJobTestUtils.getRandomCancelConnectorSyncJobActionRequest(); + + 
executeRequest(request); + + ensureNoWarnings(); + } + + private void executeRequest(CancelConnectorSyncJobAction.Request request) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + action.doExecute(mock(Task.class), request, ActionListener.wrap(response -> latch.countDown(), exception -> latch.countDown())); + + boolean requestTimedOut = latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS); + + assertTrue("Timeout waiting for cancel request", requestTimedOut); + } + +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCheckInConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCheckInConnectorSyncJobActionTests.java new file mode 100644 index 0000000000000..d88a246b6d5e2 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/TransportCheckInConnectorSyncJobActionTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; +import org.junit.Before; + +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.mock; + +public class TransportCheckInConnectorSyncJobActionTests extends ESSingleNodeTestCase { + private static final Long TIMEOUT_SECONDS = 10L; + + private final ThreadPool threadPool = new TestThreadPool(getClass().getName()); + private TransportCheckInConnectorSyncJobAction action; + + @Before + public void setup() { + ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + action = new TransportCheckInConnectorSyncJobAction(transportService, clusterService, mock(ActionFilters.class), client()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + ThreadPool.terminate(threadPool, TIMEOUT_SECONDS, TimeUnit.SECONDS); + } + + public void testCheckInConnectorSyncJob_ExpectNoWarnings() throws InterruptedException { + CheckInConnectorSyncJobAction.Request request = ConnectorSyncJobTestUtils.getRandomCheckInConnectorSyncJobActionRequest(); + + 
executeRequest(request); + + ensureNoWarnings(); + } + + private void executeRequest(CheckInConnectorSyncJobAction.Request request) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + action.doExecute(mock(Task.class), request, ActionListener.wrap(response -> latch.countDown(), exception -> latch.countDown())); + + boolean requestTimedOut = latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS); + + assertTrue("Timeout waiting for checkin request", requestTimedOut); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AbstractPageMappingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AbstractPageMappingOperator.java index ca4dbccb5b442..5924e4086c743 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AbstractPageMappingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AbstractPageMappingOperator.java @@ -95,7 +95,7 @@ public static class Status implements Operator.Status { private final int pagesProcessed; - protected Status(int pagesProcessed) { + public Status(int pagesProcessed) { this.pagesProcessed = pagesProcessed; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java index 176b2bda31e3e..3e9793ef87b2a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java @@ -52,6 +52,7 @@ public class Driver implements Releasable, Describable { private final DriverContext driverContext; private final Supplier description; private final List activeOperators; + private final List statusOfCompletedOperators = new ArrayList<>(); private final Releasable releasable; private final long statusNanos; @@ -97,7 +98,9 
@@ public Driver( this.activeOperators.add(sink); this.statusNanos = statusInterval.nanos(); this.releasable = releasable; - this.status = new AtomicReference<>(new DriverStatus(sessionId, System.currentTimeMillis(), DriverStatus.Status.QUEUED, List.of())); + this.status = new AtomicReference<>( + new DriverStatus(sessionId, System.currentTimeMillis(), DriverStatus.Status.QUEUED, List.of(), List.of()) + ); } /** @@ -229,7 +232,9 @@ private SubscribableListener runSingleLoopIteration() { List finishedOperators = this.activeOperators.subList(0, index + 1); Iterator itr = finishedOperators.iterator(); while (itr.hasNext()) { - itr.next().close(); + Operator op = itr.next(); + statusOfCompletedOperators.add(new DriverStatus.OperatorStatus(op.toString(), op.status())); + op.close(); itr.remove(); } @@ -394,7 +399,8 @@ private DriverStatus updateStatus(DriverStatus.Status status) { sessionId, System.currentTimeMillis(), status, - activeOperators.stream().map(o -> new DriverStatus.OperatorStatus(o.toString(), o.status())).toList() + statusOfCompletedOperators, + activeOperators.stream().map(op -> new DriverStatus.OperatorStatus(op.toString(), op.status())).toList() ); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java new file mode 100644 index 0000000000000..d82ddc1899b1c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.common.xcontent.ChunkedToXContentObject; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * Profile results from a single {@link Driver}. + */ +public class DriverProfile implements Writeable, ChunkedToXContentObject { + /** + * Status of each {@link Operator} in the driver when it finishes. + */ + private final List operators; + + public DriverProfile(List operators) { + this.operators = operators; + } + + public DriverProfile(StreamInput in) throws IOException { + this.operators = in.readCollectionAsImmutableList(DriverStatus.OperatorStatus::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(operators); + } + + List operators() { + return operators; + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.array("operators", operators.iterator()), + ChunkedToXContentHelper.endObject() + ); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DriverProfile that = (DriverProfile) o; + return Objects.equals(operators, that.operators); + } + + @Override + public int hashCode() { + return Objects.hash(operators); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverStatus.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverStatus.java index b3326e395def2..5a6265b37e3c6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverStatus.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverStatus.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.operator; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -46,20 +47,41 @@ public class DriverStatus implements Task.Status { * The state of the overall driver - queue, starting, running, finished. */ private final Status status; + + /** + * Status of each completed {@link Operator} in the driver. + */ + private final List completedOperators; + /** - * Status of each {@link Operator} in the driver. + * Status of each active {@link Operator} in the driver. 
*/ private final List activeOperators; - DriverStatus(String sessionId, long lastUpdated, Status status, List activeOperators) { + DriverStatus( + String sessionId, + long lastUpdated, + Status status, + List completedOperators, + List activeOperators + ) { this.sessionId = sessionId; this.lastUpdated = lastUpdated; this.status = status; + this.completedOperators = completedOperators; this.activeOperators = activeOperators; } - DriverStatus(StreamInput in) throws IOException { - this(in.readString(), in.readLong(), Status.valueOf(in.readString()), in.readCollectionAsImmutableList(OperatorStatus::new)); + public DriverStatus(StreamInput in) throws IOException { + this.sessionId = in.readString(); + this.lastUpdated = in.readLong(); + this.status = Status.valueOf(in.readString()); + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + this.completedOperators = in.readCollectionAsImmutableList(OperatorStatus::new); + } else { + this.completedOperators = List.of(); + } + this.activeOperators = in.readCollectionAsImmutableList(OperatorStatus::new); } @Override @@ -67,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(sessionId); out.writeLong(lastUpdated); out.writeString(status.toString()); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + out.writeCollection(completedOperators); + } out.writeCollection(activeOperators); } @@ -97,7 +122,14 @@ public Status status() { } /** - * Status of each {@link Operator} in the driver. + * Status of each completed {@link Operator} in the driver. + */ + public List completedOperators() { + return completedOperators; + } + + /** + * Status of each active {@link Operator} in the driver. 
*/ public List activeOperators() { return activeOperators; @@ -109,6 +141,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("sessionId", sessionId); builder.field("last_updated", DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.formatMillis(lastUpdated)); builder.field("status", status.toString().toLowerCase(Locale.ROOT)); + builder.startArray("completed_operators"); + for (OperatorStatus completed : completedOperators) { + builder.value(completed); + } + builder.endArray(); builder.startArray("active_operators"); for (OperatorStatus active : activeOperators) { builder.value(active); @@ -125,12 +162,13 @@ public boolean equals(Object o) { return sessionId.equals(that.sessionId) && lastUpdated == that.lastUpdated && status == that.status + && completedOperators.equals(that.completedOperators) && activeOperators.equals(that.activeOperators); } @Override public int hashCode() { - return Objects.hash(sessionId, lastUpdated, status, activeOperators); + return Objects.hash(sessionId, lastUpdated, status, completedOperators, activeOperators); } @Override @@ -153,12 +191,12 @@ public static class OperatorStatus implements Writeable, ToXContentObject { @Nullable private final Operator.Status status; - OperatorStatus(String operator, Operator.Status status) { + public OperatorStatus(String operator, Operator.Status status) { this.operator = operator; this.status = status; } - private OperatorStatus(StreamInput in) throws IOException { + OperatorStatus(StreamInput in) throws IOException { operator = in.readString(); status = in.readOptionalNamedWriteable(Operator.Status.class); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverProfileTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverProfileTests.java new file mode 100644 index 0000000000000..f6b4fbc817940 --- /dev/null +++ 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverProfileTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.compute.lucene.LuceneSourceOperator; +import org.elasticsearch.compute.lucene.LuceneSourceOperatorStatusTests; +import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; +import org.elasticsearch.compute.lucene.ValuesSourceReaderOperatorStatusTests; +import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperator; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class DriverProfileTests extends AbstractWireSerializingTestCase { + public void testToXContent() { + DriverProfile status = new DriverProfile( + List.of( + new DriverStatus.OperatorStatus("LuceneSource", LuceneSourceOperatorStatusTests.simple()), + new DriverStatus.OperatorStatus("ValuesSourceReader", ValuesSourceReaderOperatorStatusTests.simple()) + ) + ); + assertThat( + Strings.toString(status), + equalTo( + """ + {"operators":[""" + + """ + {"operator":"LuceneSource","status":""" + + LuceneSourceOperatorStatusTests.simpleToJson() + + "},{\"operator\":\"ValuesSourceReader\",\"status\":" + + ValuesSourceReaderOperatorStatusTests.simpleToJson() + + "}]}" + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return DriverProfile::new; + } + + @Override + protected DriverProfile createTestInstance() { + return new 
DriverProfile(DriverStatusTests.randomOperatorStatuses()); + } + + @Override + protected DriverProfile mutateInstance(DriverProfile instance) throws IOException { + var operators = randomValueOtherThan(instance.operators(), DriverStatusTests::randomOperatorStatuses); + return new DriverProfile(operators); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry( + List.of(LuceneSourceOperator.Status.ENTRY, ValuesSourceReaderOperator.Status.ENTRY, ExchangeSinkOperator.Status.ENTRY) + ); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverStatusTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverStatusTests.java index 775c30223589b..cdae4283540c4 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverStatusTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverStatusTests.java @@ -16,6 +16,8 @@ import org.elasticsearch.compute.lucene.LuceneSourceOperatorStatusTests; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperatorStatusTests; +import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperator; +import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperatorStatusTests; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.test.ESTestCase; @@ -34,21 +36,18 @@ public void testToXContent() { List.of( new DriverStatus.OperatorStatus("LuceneSource", LuceneSourceOperatorStatusTests.simple()), new DriverStatus.OperatorStatus("ValuesSourceReader", ValuesSourceReaderOperatorStatusTests.simple()) - ) - ); - assertThat( - Strings.toString(status), - equalTo( - """ - {"sessionId":"ABC:123","last_updated":"1973-11-29T09:27:23.214Z","status":"running","active_operators":[""" - + """ - 
{"operator":"LuceneSource","status":""" - + LuceneSourceOperatorStatusTests.simpleToJson() - + "},{\"operator\":\"ValuesSourceReader\",\"status\":" - + ValuesSourceReaderOperatorStatusTests.simpleToJson() - + "}]}" - ) + ), + List.of(new DriverStatus.OperatorStatus("ExchangeSink", ExchangeSinkOperatorStatusTests.simple())) ); + assertThat(Strings.toString(status), equalTo(""" + {"sessionId":"ABC:123","last_updated":"1973-11-29T09:27:23.214Z","status":"running", + """.trim() + """ + "completed_operators":[{"operator":"LuceneSource","status": + """.trim() + LuceneSourceOperatorStatusTests.simpleToJson() + """ + },{"operator":"ValuesSourceReader","status": + """.trim() + ValuesSourceReaderOperatorStatusTests.simpleToJson() + """ + }],"active_operators":[{"operator":"ExchangeSink","status": + """.trim() + ExchangeSinkOperatorStatusTests.simpleToJson() + "}]}")); } @Override @@ -58,7 +57,7 @@ protected Writeable.Reader instanceReader() { @Override protected DriverStatus createTestInstance() { - return new DriverStatus(randomSessionId(), randomLong(), randomStatus(), randomActiveOperators()); + return new DriverStatus(randomSessionId(), randomLong(), randomStatus(), randomOperatorStatuses(), randomOperatorStatuses()); } private String randomSessionId() { @@ -69,14 +68,15 @@ private DriverStatus.Status randomStatus() { return randomFrom(DriverStatus.Status.values()); } - private List randomActiveOperators() { - return randomList(0, 5, this::randomOperatorStatus); + static List randomOperatorStatuses() { + return randomList(0, 5, DriverStatusTests::randomOperatorStatus); } - private DriverStatus.OperatorStatus randomOperatorStatus() { + private static DriverStatus.OperatorStatus randomOperatorStatus() { Supplier status = randomFrom( new LuceneSourceOperatorStatusTests()::createTestInstance, new ValuesSourceReaderOperatorStatusTests()::createTestInstance, + new ExchangeSinkOperatorStatusTests()::createTestInstance, () -> null ); return new 
DriverStatus.OperatorStatus(randomAlphaOfLength(3), status.get()); @@ -87,8 +87,9 @@ protected DriverStatus mutateInstance(DriverStatus instance) throws IOException var sessionId = instance.sessionId(); long lastUpdated = instance.lastUpdated(); var status = instance.status(); - var operators = instance.activeOperators(); - switch (between(0, 3)) { + var completedOperators = instance.completedOperators(); + var activeOperators = instance.activeOperators(); + switch (between(0, 4)) { case 0: sessionId = randomValueOtherThan(sessionId, this::randomSessionId); break; @@ -99,16 +100,21 @@ protected DriverStatus mutateInstance(DriverStatus instance) throws IOException status = randomValueOtherThan(status, this::randomStatus); break; case 3: - operators = randomValueOtherThan(operators, this::randomActiveOperators); + completedOperators = randomValueOtherThan(completedOperators, DriverStatusTests::randomOperatorStatuses); + break; + case 4: + activeOperators = randomValueOtherThan(activeOperators, DriverStatusTests::randomOperatorStatuses); break; default: throw new UnsupportedOperationException(); } - return new DriverStatus(sessionId, lastUpdated, status, operators); + return new DriverStatus(sessionId, lastUpdated, status, completedOperators, activeOperators); } @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(List.of(LuceneSourceOperator.Status.ENTRY, ValuesSourceReaderOperator.Status.ENTRY)); + return new NamedWriteableRegistry( + List.of(LuceneSourceOperator.Status.ENTRY, ValuesSourceReaderOperator.Status.ENTRY, ExchangeSinkOperator.Status.ENTRY) + ); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperatorStatusTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperatorStatusTests.java index f342720b99903..7438055284b14 100644 --- 
a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperatorStatusTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkOperatorStatusTests.java @@ -17,8 +17,16 @@ public class ExchangeSinkOperatorStatusTests extends AbstractWireSerializingTestCase { public void testToXContent() { - assertThat(Strings.toString(new ExchangeSinkOperator.Status(10)), equalTo(""" - {"pages_accepted":10}""")); + assertThat(Strings.toString(simple()), equalTo(simpleToJson())); + } + + public static ExchangeSinkOperator.Status simple() { + return new ExchangeSinkOperator.Status(10); + } + + public static String simpleToJson() { + return """ + {"pages_accepted":10}"""; } @Override @@ -27,7 +35,7 @@ protected Writeable.Reader instanceReader() { } @Override - protected ExchangeSinkOperator.Status createTestInstance() { + public ExchangeSinkOperator.Status createTestInstance() { return new ExchangeSinkOperator.Status(between(0, Integer.MAX_VALUE)); } diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/GenerativeIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/GenerativeIT.java new file mode 100644 index 0000000000000..e499b13bf1db8 --- /dev/null +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/GenerativeIT.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.qa.single_node; + +import org.apache.lucene.tests.util.LuceneTestCase.AwaitsFix; +import org.elasticsearch.xpack.esql.qa.rest.generative.GenerativeRestTest; + +@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102084") +public class GenerativeIT extends GenerativeRestTest {} diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java new file mode 100644 index 0000000000000..25530e3d744ad --- /dev/null +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java @@ -0,0 +1,401 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.qa.rest.generative; + +import org.elasticsearch.xpack.esql.CsvTestsDataLoader; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; +import static org.elasticsearch.test.ESTestCase.randomBoolean; +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.elasticsearch.test.ESTestCase.randomIntBetween; +import static org.elasticsearch.test.ESTestCase.randomLongBetween; + +public class EsqlQueryGenerator { + + public record Column(String name, String type) {} + + public record QueryExecuted(String query, int depth, List outputSchema, Exception exception) {} + + public static String sourceCommand(List availabeIndices) { + return switch (randomIntBetween(0, 2)) { + case 0 -> from(availabeIndices); + case 1 -> showFunctions(); + default -> row(); + }; + + } + + /** + * @param previousOutput a list of fieldName+type + * @param policies + * @return a new command that can process it as input + */ + public static String pipeCommand(List previousOutput, List policies) { + return switch (randomIntBetween(0, 11)) { + case 0 -> dissect(previousOutput); + case 1 -> drop(previousOutput); + case 2 -> enrich(previousOutput, policies); + case 3 -> eval(previousOutput); + case 4 -> grok(previousOutput); + case 5 -> keep(previousOutput); + case 6 -> limit(); + case 7 -> mvExpand(previousOutput); + case 8 -> rename(previousOutput); + case 9 -> sort(previousOutput); + case 10 -> stats(previousOutput); + default -> where(previousOutput); + }; + } + + private static String where(List previousOutput) { + // TODO more complex conditions + StringBuilder result = new StringBuilder(" | where "); + int nConditions = randomIntBetween(1, 5); + for (int i = 0; i < nConditions; i++) { + String exp = booleanExpression(previousOutput); + if (exp == null) { + // cannot generate 
expressions, just skip + return ""; + } + if (i > 0) { + result.append(randomBoolean() ? " AND " : " OR "); + } + if (randomBoolean()) { + result.append(" NOT "); + } + result.append(exp); + } + + return result.toString(); + } + + private static String booleanExpression(List previousOutput) { + // TODO LIKE, RLIKE, functions etc. + return switch (randomIntBetween(0, 3)) { + case 0 -> { + String field = randomNumericField(previousOutput); + if (field == null) { + yield null; + } + yield field + " " + mathCompareOperator() + " 50"; + } + case 1 -> "true"; + default -> "false"; + }; + } + + private static String mathCompareOperator() { + return switch (randomIntBetween(0, 5)) { + case 0 -> "=="; + case 1 -> ">"; + case 2 -> ">="; + case 3 -> "<"; + case 4 -> "<="; + default -> "!="; + }; + } + + private static String enrich(List previousOutput, List policies) { + String field = randomKeywordField(previousOutput); + if (field == null || policies.isEmpty()) { + return ""; + } + // TODO add WITH + return " | enrich " + randomFrom(policies).policyName() + " on " + field; + } + + private static String grok(List previousOutput) { + String field = randomStringField(previousOutput); + if (field == null) { + return "";// no strings to grok, just skip + } + StringBuilder result = new StringBuilder(" | grok "); + result.append(field); + result.append(" \""); + for (int i = 0; i < randomIntBetween(1, 3); i++) { + if (i > 0) { + result.append(" "); + } + result.append("%{WORD:"); + if (randomBoolean()) { + result.append(randomAlphaOfLength(5)); + } else { + result.append(randomName(previousOutput)); + } + result.append("}"); + } + result.append("\""); + return result.toString(); + } + + private static String dissect(List previousOutput) { + String field = randomStringField(previousOutput); + if (field == null) { + return "";// no strings to dissect, just skip + } + StringBuilder result = new StringBuilder(" | dissect "); + result.append(field); + result.append(" \""); + for (int i 
= 0; i < randomIntBetween(1, 3); i++) { + if (i > 0) { + result.append(" "); + } + result.append("%{"); + if (randomBoolean()) { + result.append(randomAlphaOfLength(5)); + } else { + result.append(randomName(previousOutput)); + } + result.append("}"); + } + result.append("\""); + return result.toString(); + } + + private static String keep(List previousOutput) { + int n = randomIntBetween(1, previousOutput.size()); + Set proj = new HashSet<>(); + for (int i = 0; i < n; i++) { + if (randomIntBetween(0, 100) < 5) { + proj.add("*"); + } else { + String name = randomName(previousOutput); + if (name.length() > 1 && randomIntBetween(0, 100) < 10) { + if (randomBoolean()) { + name = name.substring(0, randomIntBetween(1, name.length() - 1)) + "*"; + } else { + name = "*" + name.substring(randomIntBetween(1, name.length() - 1)); + } + } + proj.add(name); + } + } + return " | keep " + proj.stream().collect(Collectors.joining(", ")); + } + + private static String randomName(List previousOutput) { + return previousOutput.get(randomIntBetween(0, previousOutput.size() - 1)).name(); + } + + private static String rename(List previousOutput) { + int n = randomIntBetween(1, Math.min(3, previousOutput.size())); + List proj = new ArrayList<>(); + List names = new ArrayList<>(previousOutput.stream().map(Column::name).collect(Collectors.toList())); + for (int i = 0; i < n; i++) { + String name = names.remove(randomIntBetween(0, names.size() - 1)); + String newName; + if (names.isEmpty() || randomBoolean()) { + newName = randomAlphaOfLength(5); + } else { + newName = names.get(randomIntBetween(0, names.size() - 1)); + } + names.add(newName); + proj.add(name + " AS " + newName); + } + return " | rename " + proj.stream().collect(Collectors.joining(", ")); + } + + private static String drop(List previousOutput) { + if (previousOutput.size() < 2) { + return ""; // don't drop all of them, just do nothing + } + int n = randomIntBetween(1, previousOutput.size() - 1); + Set proj = new 
HashSet<>(); + for (int i = 0; i < n; i++) { + String name = randomName(previousOutput); + if (name.length() > 1 && randomIntBetween(0, 100) < 10) { + if (randomBoolean()) { + name = name.substring(0, randomIntBetween(1, name.length() - 1)) + "*"; + } else { + name = "*" + name.substring(randomIntBetween(1, name.length() - 1)); + } + } + proj.add(name); + } + return " | drop " + proj.stream().collect(Collectors.joining(", ")); + } + + private static String sort(List previousOutput) { + int n = randomIntBetween(1, previousOutput.size()); + Set proj = new HashSet<>(); + for (int i = 0; i < n; i++) { + proj.add(randomName(previousOutput)); + } + return " | sort " + + proj.stream() + .map(x -> x + randomFrom("", " ASC", " DESC") + randomFrom("", " NULLS FIRST", " NULLS LAST")) + .collect(Collectors.joining(", ")); + } + + private static String mvExpand(List previousOutput) { + return " | mv_expand " + randomName(previousOutput); + } + + private static String eval(List previousOutput) { + StringBuilder cmd = new StringBuilder(" | eval "); + int nFields = randomIntBetween(1, 10); + // TODO pass newly created fields to next expressions + for (int i = 0; i < nFields; i++) { + String name; + if (randomBoolean()) { + name = randomAlphaOfLength(randomIntBetween(3, 10)); + } else { + name = randomName(previousOutput); + } + String expression = expression(previousOutput); + if (i > 0) { + cmd.append(","); + } + cmd.append(" "); + cmd.append(name); + cmd.append(" = "); + cmd.append(expression); + } + return cmd.toString(); + } + + private static String stats(List previousOutput) { + List nonNull = previousOutput.stream().filter(x -> x.type().equals("null") == false).collect(Collectors.toList()); + if (nonNull.isEmpty()) { + return ""; // cannot do any stats, just skip + } + StringBuilder cmd = new StringBuilder(" | stats "); + int nStats = randomIntBetween(1, 5); + for (int i = 0; i < nStats; i++) { + String name; + if (randomBoolean()) { + name = 
randomAlphaOfLength(randomIntBetween(3, 10)); + } else { + name = randomName(previousOutput); + } + String expression = agg(nonNull); + if (i > 0) { + cmd.append(","); + } + cmd.append(" "); + cmd.append(name); + cmd.append(" = "); + cmd.append(expression); + } + if (randomBoolean()) { + cmd.append(" by "); + + cmd.append(randomName(nonNull)); + } + return cmd.toString(); + } + + private static String agg(List previousOutput) { + String name = randomNumericOrDateField(previousOutput); + if (name != null && randomBoolean()) { + // numerics only + return switch (randomIntBetween(0, 1)) { + case 0 -> "max(" + name + ")"; + default -> "min(" + name + ")"; + // TODO more numerics + }; + } + // all types + name = randomName(previousOutput); + return switch (randomIntBetween(0, 2)) { + case 0 -> "count(*)"; + case 1 -> "count(" + name + ")"; + default -> "count_distinct(" + name + ")"; + }; + } + + private static String randomNumericOrDateField(List previousOutput) { + return randomName(previousOutput, Set.of("long", "integer", "double", "date")); + } + + private static String randomNumericField(List previousOutput) { + return randomName(previousOutput, Set.of("long", "integer", "double")); + } + + private static String randomStringField(List previousOutput) { + return randomName(previousOutput, Set.of("text", "keyword")); + } + + private static String randomKeywordField(List previousOutput) { + return randomName(previousOutput, Set.of("keyword")); + } + + private static String randomName(List cols, Set allowedTypes) { + List items = cols.stream().filter(x -> allowedTypes.contains(x.type())).map(Column::name).collect(Collectors.toList()); + if (items.size() == 0) { + return null; + } + return items.get(randomIntBetween(0, items.size() - 1)); + } + + private static String expression(List previousOutput) { + // TODO improve!!! 
+ return constantExpression(); + } + + public static String limit() { + return " | limit " + randomIntBetween(0, 15000); + } + + private static String from(List availabeIndices) { + StringBuilder result = new StringBuilder("from "); + int items = randomIntBetween(1, 3); + for (int i = 0; i < items; i++) { + String pattern = indexPattern(availabeIndices.get(randomIntBetween(0, availabeIndices.size() - 1))); + if (i > 0) { + result.append(","); + } + result.append(pattern); + } + return result.toString(); + } + + private static String showFunctions() { + return "show functions"; + } + + private static String indexPattern(String indexName) { + return randomBoolean() ? indexName : indexName.substring(0, randomIntBetween(0, indexName.length())) + "*"; + } + + private static String row() { + StringBuilder cmd = new StringBuilder("row "); + int nFields = randomIntBetween(1, 10); + for (int i = 0; i < nFields; i++) { + String name = randomAlphaOfLength(randomIntBetween(3, 10)); + String expression = constantExpression(); + if (i > 0) { + cmd.append(","); + } + cmd.append(" "); + cmd.append(name); + cmd.append(" = "); + cmd.append(expression); + } + return cmd.toString(); + } + + private static String constantExpression() { + // TODO not only simple values, but also foldable expressions + return switch (randomIntBetween(0, 4)) { + case 0 -> "" + randomIntBetween(Integer.MIN_VALUE, Integer.MAX_VALUE); + case 1 -> "" + randomLongBetween(Long.MIN_VALUE, Long.MAX_VALUE); + case 2 -> "\"" + randomAlphaOfLength(randomIntBetween(0, 20)) + "\""; + case 3 -> "" + randomBoolean(); + default -> "null"; + }; + + } + +} diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java new file mode 100644 index 0000000000000..9ba54ea1941fd --- /dev/null +++ 
b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.qa.rest.generative; + +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.esql.CsvTestsDataLoader; +import org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP; +import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.ENRICH_POLICIES; +import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs; + +public abstract class GenerativeRestTest extends ESRestTestCase { + + public static final int ITERATIONS = 50; + public static final int MAX_DEPTH = 10; + + public static final Set ALLOWED_ERRORS = Set.of( + "is ambiguous (to disambiguate use quotes or qualifiers)", + "due to ambiguities being mapped as" + ); + + @Before + public void setup() throws IOException { + if (indexExists(CSV_DATASET_MAP.keySet().iterator().next()) == false) { + loadDataSetIntoEs(client()); + } + } + + public void test() { + List indices = availableIndices(); + List policies = availableEnrichPolicies(); + for (int i = 0; i < ITERATIONS; i++) { + String command = EsqlQueryGenerator.sourceCommand(indices); + EsqlQueryGenerator.QueryExecuted result = execute(command, 0); + if (result.exception() != null) { + checkException(result); + continue; + } + for (int j = 0; j < MAX_DEPTH; j++) { + if (result.outputSchema().isEmpty()) { + break; + } + 
command = EsqlQueryGenerator.pipeCommand(result.outputSchema(), policies); + result = execute(result.query() + command, result.depth() + 1); + if (result.exception() != null) { + checkException(result); + break; + } + } + } + } + + private void checkException(EsqlQueryGenerator.QueryExecuted query) { + for (String allowedError : ALLOWED_ERRORS) { + if (query.exception().getMessage().contains(allowedError)) { + return; + } + } + fail("query: " + query.query() + "\nexception: " + query.exception().getMessage()); + } + + private EsqlQueryGenerator.QueryExecuted execute(String command, int depth) { + try { + Map a = RestEsqlTestCase.runEsql(new RestEsqlTestCase.RequestObjectBuilder().query(command).build()); + List outputSchema = outputSchema(a); + return new EsqlQueryGenerator.QueryExecuted(command, depth, outputSchema, null); + } catch (Exception e) { + return new EsqlQueryGenerator.QueryExecuted(command, depth, null, e); + } + + } + + @SuppressWarnings("unchecked") + private List outputSchema(Map a) { + List> cols = (List>) a.get("columns"); + if (cols == null) { + return null; + } + return cols.stream().map(x -> new EsqlQueryGenerator.Column(x.get("name"), x.get("type"))).collect(Collectors.toList()); + } + + private List availableIndices() { + return new ArrayList<>(CSV_DATASET_MAP.keySet()); + } + + List availableEnrichPolicies() { + return ENRICH_POLICIES; + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index d35f7898d937f..eca8beb06576b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -105,7 +105,8 @@ public static EsqlConfiguration configuration(QueryPragmas pragmas, String query pragmas, 
EsqlPlugin.QUERY_RESULT_TRUNCATION_MAX_SIZE.getDefault(Settings.EMPTY), EsqlPlugin.QUERY_RESULT_TRUNCATION_DEFAULT_SIZE.getDefault(Settings.EMPTY), - query + query, + false ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequest.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequest.java index c467f0dfc9075..0de89a4d8de2a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequest.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryRequest.java @@ -62,11 +62,13 @@ public class EsqlQueryRequest extends ActionRequest implements CompositeIndicesR private static final ParseField PRAGMA_FIELD = new ParseField("pragma"); private static final ParseField PARAMS_FIELD = new ParseField("params"); private static final ParseField LOCALE_FIELD = new ParseField("locale"); + private static final ParseField PROFILE_FIELD = new ParseField("profile"); private static final ObjectParser PARSER = objectParser(EsqlQueryRequest::new); private String query; private boolean columnar; + private boolean profile; private Locale locale; private QueryBuilder filter; private QueryPragmas pragmas = new QueryPragmas(Settings.EMPTY); @@ -106,6 +108,21 @@ public boolean columnar() { return columnar; } + /** + * Enable profiling, sacrificing performance to return information about + * what operations are taking the most time. + */ + public void profile(boolean profile) { + this.profile = profile; + } + + /** + * Is profiling enabled? 
+ */ + public boolean profile() { + return profile; + } + public void locale(Locale locale) { this.locale = locale; } @@ -154,6 +171,7 @@ private static ObjectParser objectParser(Supplier request.locale(Locale.forLanguageTag(localeTag)), LOCALE_FIELD); + parser.declareBoolean(EsqlQueryRequest::profile, PROFILE_FIELD); return parser; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java index fea9dd6c526c3..b283231574540 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java @@ -8,15 +8,18 @@ package org.elasticsearch.xpack.esql.action; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.common.xcontent.ChunkedToXContentObject; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; @@ -28,6 +31,8 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.UnsupportedValueSource; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; 
import org.elasticsearch.core.Releasables; import org.elasticsearch.search.DocValueFormat; @@ -62,12 +67,7 @@ import static org.elasticsearch.xpack.ql.util.SpatialCoordinateTypes.GEO; import static org.elasticsearch.xpack.ql.util.StringUtils.parseIP; -public class EsqlQueryResponse extends ActionResponse implements ChunkedToXContent, Releasable { - - private final List columns; - private final List pages; - private final boolean columnar; - +public class EsqlQueryResponse extends ActionResponse implements ChunkedToXContentObject, Releasable { private static final InstantiatingObjectParser PARSER; static { InstantiatingObjectParser.Builder parser = InstantiatingObjectParser.builder( @@ -80,15 +80,22 @@ public class EsqlQueryResponse extends ActionResponse implements ChunkedToXConte PARSER = parser.build(); } - public EsqlQueryResponse(List columns, List pages, boolean columnar) { + private final List columns; + private final List pages; + private final Profile profile; + private final boolean columnar; + + public EsqlQueryResponse(List columns, List pages, @Nullable Profile profile, boolean columnar) { this.columns = columns; this.pages = pages; + this.profile = profile; this.columnar = columnar; } public EsqlQueryResponse(List columns, List> values) { this.columns = columns; this.pages = List.of(valuesToPage(columns.stream().map(ColumnInfo::type).toList(), values)); + this.profile = null; this.columnar = false; } @@ -99,10 +106,15 @@ public static Writeable.Reader reader(BlockFactory blockFacto return in -> new EsqlQueryResponse(new BlockStreamInput(in, blockFactory)); } - public EsqlQueryResponse(BlockStreamInput in) throws IOException { + private EsqlQueryResponse(BlockStreamInput in) throws IOException { super(in); this.columns = in.readCollectionAsList(ColumnInfo::new); this.pages = in.readCollectionAsList(Page::new); + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + this.profile = in.readOptionalWriteable(Profile::new); + } else { + 
this.profile = null; + } this.columnar = in.readBoolean(); } @@ -110,6 +122,9 @@ public EsqlQueryResponse(BlockStreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeCollection(columns); out.writeCollection(pages); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + out.writeOptionalWriteable(profile); + } out.writeBoolean(columnar); } @@ -125,12 +140,16 @@ public Iterator> values() { return pagesToValues(columns.stream().map(ColumnInfo::type).toList(), pages); } + public Profile profile() { + return profile; + } + public boolean columnar() { return columnar; } @Override - public Iterator toXContentChunked(ToXContent.Params unused) { + public Iterator toXContentChunked(ToXContent.Params params) { final BytesRef scratch = new BytesRef(); final Iterator valuesIt; if (pages.isEmpty()) { @@ -141,14 +160,14 @@ public Iterator toXContentChunked(ToXContent.Params unused 0, columns().size(), column -> Iterators.concat( - Iterators.single(((builder, params) -> builder.startArray())), + Iterators.single(((builder, p) -> builder.startArray())), Iterators.flatMap(pages.iterator(), page -> { ColumnInfo.PositionToXContent toXContent = columns.get(column) .positionToXContent(page.getBlock(column), scratch); return Iterators.forRange( 0, page.getPositionCount(), - position -> (builder, params) -> toXContent.positionToXContent(builder, params, position) + position -> (builder, p) -> toXContent.positionToXContent(builder, p, position) ); }), ChunkedToXContentHelper.endArray() @@ -164,22 +183,32 @@ public Iterator toXContentChunked(ToXContent.Params unused for (int column = 0; column < columnCount; column++) { toXContents[column] = columns.get(column).positionToXContent(page.getBlock(column), scratch); } - return Iterators.forRange(0, page.getPositionCount(), position -> (builder, params) -> { + return Iterators.forRange(0, page.getPositionCount(), position -> (builder, p) -> { builder.startArray(); for (int 
c = 0; c < columnCount; c++) { - toXContents[c].positionToXContent(builder, params, position); + toXContents[c].positionToXContent(builder, p, position); } return builder.endArray(); }); }); } - return Iterators.concat(ChunkedToXContentHelper.startObject(), ChunkedToXContentHelper.singleChunk((builder, params) -> { + Iterator columnsRender = ChunkedToXContentHelper.singleChunk((builder, p) -> { builder.startArray("columns"); for (ColumnInfo col : columns) { - col.toXContent(builder, params); + col.toXContent(builder, p); } return builder.endArray(); - }), ChunkedToXContentHelper.array("values", valuesIt), ChunkedToXContentHelper.endObject()); + }); + Iterator profileRender = profile == null + ? List.of().iterator() + : ChunkedToXContentHelper.field("profile", profile, params); + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + columnsRender, + ChunkedToXContentHelper.array("values", valuesIt), + profileRender, + ChunkedToXContentHelper.endObject() + ); } @Override @@ -198,7 +227,8 @@ public boolean equals(Object o) { EsqlQueryResponse that = (EsqlQueryResponse) o; return Objects.equals(columns, that.columns) && columnar == that.columnar - && Iterators.equals(values(), that.values(), (row1, row2) -> Iterators.equals(row1, row2, Objects::equals)); + && Iterators.equals(values(), that.values(), (row1, row2) -> Iterators.equals(row1, row2, Objects::equals)) + && Objects.equals(profile, that.profile); } @Override @@ -336,4 +366,51 @@ private static Page valuesToPage(List dataTypes, List> valu } return new Page(results.stream().map(Block.Builder::build).toArray(Block[]::new)); } + + public static class Profile implements Writeable, ChunkedToXContentObject { + private final List drivers; + + public Profile(List drivers) { + this.drivers = drivers; + } + + public Profile(StreamInput in) throws IOException { + this.drivers = in.readCollectionAsImmutableList(DriverProfile::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { 
+ out.writeCollection(drivers); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Profile profile = (Profile) o; + return Objects.equals(drivers, profile.drivers); + } + + @Override + public int hashCode() { + return Objects.hash(drivers); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.array("drivers", drivers.iterator(), params), + ChunkedToXContentHelper.endObject() + ); + } + + List drivers() { + return drivers; + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index c28867f89c981..8d7024f7d889d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.plugin; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.search.SearchRequest; @@ -28,6 +29,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.DriverTaskRunner; import org.elasticsearch.compute.operator.exchange.ExchangeResponse; import org.elasticsearch.compute.operator.exchange.ExchangeService; @@ -86,6 +88,8 @@ * Computes the result of a {@link PhysicalPlan}. 
*/ public class ComputeService { + public record Result(List pages, List profiles) {} + private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); private final SearchService searchService; private final BigArrays bigArrays; @@ -122,7 +126,7 @@ public void execute( CancellableTask rootTask, PhysicalPlan physicalPlan, EsqlConfiguration configuration, - ActionListener> listener + ActionListener listener ) { Tuple coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode( physicalPlan, @@ -142,7 +146,12 @@ public void execute( if (concreteIndices.isEmpty()) { var computeContext = new ComputeContext(sessionId, List.of(), configuration, null, null); - runCompute(rootTask, computeContext, coordinatorPlan, listener.map(unused -> collectedPages)); + runCompute( + rootTask, + computeContext, + coordinatorPlan, + listener.map(driverProfiles -> new Result(collectedPages, driverProfiles)) + ); return; } QueryBuilder requestFilter = PlannerUtils.requestFilter(dataNodePlan); @@ -161,18 +170,32 @@ public void execute( queryPragmas.exchangeBufferSize(), ESQL_THREAD_POOL_NAME ); + final List collectedProfiles = configuration.profile() + ? 
Collections.synchronizedList(new ArrayList<>()) + : null; try ( Releasable ignored = exchangeSource::decRef; - RefCountingListener requestRefs = new RefCountingListener(delegate.map(unused -> collectedPages)) + RefCountingListener requestRefs = new RefCountingListener( + delegate.map(unused -> new Result(collectedPages, collectedProfiles)) + ) ) { final AtomicBoolean cancelled = new AtomicBoolean(); // wait until the source handler is completed exchangeSource.addCompletionListener(requestRefs.acquire()); // run compute on the coordinator var computeContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null); - runCompute(rootTask, computeContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire())); + runCompute( + rootTask, + computeContext, + coordinatorPlan, + cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(driverProfiles -> { + if (configuration.profile()) { + collectedProfiles.addAll(driverProfiles); + } + return null; + }) + ); // run compute on remote nodes - // TODO: This is wrong, we need to be able to cancel runComputeOnRemoteNodes( sessionId, rootTask, @@ -180,7 +203,12 @@ public void execute( dataNodePlan, exchangeSource, targetNodes, - () -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(unused -> null) + () -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(response -> { + if (configuration.profile()) { + collectedProfiles.addAll(response.profiles); + } + return null; + }) ); } }) @@ -241,7 +269,7 @@ private ActionListener cancelOnFailure(CancellableTask task, AtomicBoolean }); } - void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener listener) { + void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener> listener) { listener = ActionListener.runAfter(listener, () -> Releasables.close(context.searchContexts)); final List drivers; try { @@ -273,11 +301,18 @@ void 
runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, listener.onFailure(e); return; } + ActionListener listenerCollectingStatus = listener.map(ignored -> { + if (context.configuration.profile()) { + return drivers.stream().map(d -> new DriverProfile(d.status().completedOperators())).toList(); + } + return null; + }); + listenerCollectingStatus = ActionListener.releaseAfter(listenerCollectingStatus, () -> Releasables.close(drivers)); driverRunner.executeDrivers( task, drivers, transportService.getThreadPool().executor(ESQL_WORKER_THREAD_POOL_NAME), - ActionListener.releaseAfter(listener, () -> Releasables.close(drivers)) + listenerCollectingStatus ); } @@ -412,17 +447,36 @@ private void computeTargetNodes( } } - // TODO: To include stats/profiles private static class DataNodeResponse extends TransportResponse { - DataNodeResponse() {} + private final List profiles; + + DataNodeResponse(List profiles) { + this.profiles = profiles; + } DataNodeResponse(StreamInput in) throws IOException { super(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + if (in.readBoolean()) { + profiles = in.readCollectionAsImmutableList(DriverProfile::new); + } else { + profiles = null; + } + } else { + profiles = null; + } } @Override - public void writeTo(StreamOutput out) { - + public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + if (profiles == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeCollection(profiles); + } + } } } @@ -436,13 +490,16 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T final var sessionId = request.sessionId(); final var exchangeSink = exchangeService.getSinkHandler(sessionId); parentTask.addListener(() -> exchangeService.finishSinkHandler(sessionId, new TaskCancelledException("task cancelled"))); - final ActionListener listener = new 
OwningChannelActionListener<>(channel).map(nullValue -> new DataNodeResponse()); + final ActionListener listener = new OwningChannelActionListener<>(channel); acquireSearchContexts(request.shardIds(), request.aliasFilters(), ActionListener.wrap(searchContexts -> { var computeContext = new ComputeContext(sessionId, searchContexts, request.configuration(), null, exchangeSink); - runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(unused -> { + runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(driverProfiles -> { // don't return until all pages are fetched exchangeSink.addCompletionListener( - ActionListener.releaseAfter(listener, () -> exchangeService.finishSinkHandler(sessionId, null)) + ActionListener.releaseAfter( + listener.map(nullValue -> new DataNodeResponse(driverProfiles)), + () -> exchangeService.finishSinkHandler(sessionId, null) + ) ); }, e -> { exchangeService.finishSinkHandler(sessionId, e); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index de4af3497d80d..780d812e2c23b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -96,7 +96,8 @@ private void doExecuteForked(Task task, EsqlQueryRequest request, ActionListener request.pragmas(), clusterService.getClusterSettings().get(EsqlPlugin.QUERY_RESULT_TRUNCATION_MAX_SIZE), clusterService.getClusterSettings().get(EsqlPlugin.QUERY_RESULT_TRUNCATION_DEFAULT_SIZE), - request.query() + request.query(), + request.profile() ); String sessionId = sessionID(task); planExecutor.esql( @@ -110,12 +111,15 @@ private void doExecuteForked(Task task, EsqlQueryRequest request, ActionListener (CancellableTask) task, physicalPlan, configuration, - 
delegate.map(pages -> { + delegate.map(result -> { List columns = physicalPlan.output() .stream() .map(c -> new ColumnInfo(c.qualifiedName(), EsqlDataTypes.outputType(c.dataType()))) .toList(); - return new EsqlQueryResponse(columns, pages, request.columnar()); + EsqlQueryResponse.Profile profile = configuration.profile() + ? new EsqlQueryResponse.Profile(result.profiles()) + : null; + return new EsqlQueryResponse(columns, result.pages(), profile, request.columnar()); }) ) ) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlConfiguration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlConfiguration.java index 7549552dae55b..ac13f25c2d2a9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlConfiguration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlConfiguration.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.session; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.compress.CompressorFactory; import org.elasticsearch.common.io.stream.StreamInput; @@ -37,6 +38,8 @@ public class EsqlConfiguration extends Configuration implements Writeable { private final String query; + private final boolean profile; + public EsqlConfiguration( ZoneId zi, Locale locale, @@ -45,7 +48,8 @@ public EsqlConfiguration( QueryPragmas pragmas, int resultTruncationMaxSize, int resultTruncationDefaultSize, - String query + String query, + boolean profile ) { super(zi, username, clusterName); this.locale = locale; @@ -53,6 +57,7 @@ public EsqlConfiguration( this.resultTruncationMaxSize = resultTruncationMaxSize; this.resultTruncationDefaultSize = resultTruncationDefaultSize; this.query = query; + this.profile = profile; } public EsqlConfiguration(StreamInput in) throws IOException { @@ -62,6 +67,11 @@ public EsqlConfiguration(StreamInput in) throws IOException { 
this.resultTruncationMaxSize = in.readVInt(); this.resultTruncationDefaultSize = in.readVInt(); this.query = readQuery(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + this.profile = in.readBoolean(); + } else { + this.profile = false; + } } @Override @@ -77,6 +87,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVInt(resultTruncationMaxSize); out.writeVInt(resultTruncationDefaultSize); writeQuery(out, query); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) { + out.writeBoolean(profile); + } } public QueryPragmas pragmas() { @@ -99,6 +112,14 @@ public String query() { return query; } + /** + * Enable profiling, sacrificing performance to return information about + * what operations are taking the most time. + */ + public boolean profile() { + return profile; + } + private static void writeQuery(StreamOutput out, String query) throws IOException { if (query.length() > QUERY_COMPRESS_THRESHOLD_CHARS) { // compare on chars to avoid UTF-8 encoding unless actually required out.writeBoolean(true); @@ -130,13 +151,14 @@ public boolean equals(Object o) { && resultTruncationDefaultSize == that.resultTruncationDefaultSize && Objects.equals(pragmas, that.pragmas) && Objects.equals(locale, that.locale) - && Objects.equals(that.query, query); + && Objects.equals(that.query, query) + && profile == that.profile; } return false; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), pragmas, resultTruncationMaxSize, resultTruncationDefaultSize, locale, query); + return Objects.hash(super.hashCode(), pragmas, resultTruncationMaxSize, resultTruncationDefaultSize, locale, query, profile); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseProfileTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseProfileTests.java new file mode 100644 index 0000000000000..af8f6dcd550c4 --- /dev/null 
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseProfileTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.AbstractPageMappingOperator; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.compute.operator.DriverStatus; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.List; +import java.util.stream.Stream; + +public class EsqlQueryResponseProfileTests extends AbstractWireSerializingTestCase { + @Override + protected Writeable.Reader instanceReader() { + return EsqlQueryResponse.Profile::new; + } + + @Override + protected EsqlQueryResponse.Profile createTestInstance() { + return new EsqlQueryResponse.Profile(randomDriverProfiles()); + } + + @Override + protected EsqlQueryResponse.Profile mutateInstance(EsqlQueryResponse.Profile instance) { + return new EsqlQueryResponse.Profile(randomValueOtherThan(instance.drivers(), this::randomDriverProfiles)); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry( + Stream.concat(Stream.of(AbstractPageMappingOperator.Status.ENTRY), Block.getNamedWriteables().stream()).toList() + ); + } + + private List randomDriverProfiles() { + return randomList(10, this::randomDriverProfile); + } + + private DriverProfile randomDriverProfile() { + return new DriverProfile(randomList(10, this::randomOperatorStatus)); + } + + private 
DriverStatus.OperatorStatus randomOperatorStatus() { + String name = randomAlphaOfLength(4); + Operator.Status status = randomBoolean() ? null : new AbstractPageMappingOperator.Status(between(0, Integer.MAX_VALUE)); + return new DriverStatus.OperatorStatus(name, status); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java index 5b2aba2e9e1f3..f040933e01410 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java @@ -29,6 +29,9 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.lucene.UnsupportedValueSource; +import org.elasticsearch.compute.operator.AbstractPageMappingOperator; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.compute.operator.DriverStatus; import org.elasticsearch.core.Releasables; import org.elasticsearch.test.AbstractChunkedSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; @@ -46,6 +49,7 @@ import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.List; +import java.util.stream.Stream; import static org.elasticsearch.xpack.ql.util.SpatialCoordinateTypes.CARTESIAN; import static org.elasticsearch.xpack.ql.util.SpatialCoordinateTypes.GEO; @@ -67,27 +71,29 @@ public void blockFactoryEmpty() { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(Block.getNamedWriteables()); + return new NamedWriteableRegistry( + Stream.concat(Stream.of(AbstractPageMappingOperator.Status.ENTRY), Block.getNamedWriteables().stream()).toList() + ); } @Override protected EsqlQueryResponse createXContextTestInstance(XContentType xContentType) { // 
columnar param can't be different from the default value (false) since the EsqlQueryResponse will be serialized (by some random // XContentType, not to a StreamOutput) and parsed back, which doesn't preserve columnar field's value. - return randomResponse(false); + return randomResponse(false, null); } @Override protected EsqlQueryResponse createTestInstance() { - return randomResponse(randomBoolean()); + return randomResponse(randomBoolean(), randomProfile()); } - EsqlQueryResponse randomResponse(boolean columnar) { + EsqlQueryResponse randomResponse(boolean columnar, EsqlQueryResponse.Profile profile) { int noCols = randomIntBetween(1, 10); List columns = randomList(noCols, noCols, this::randomColumnInfo); int noPages = randomIntBetween(1, 20); List values = randomList(noPages, noPages, () -> randomPage(columns)); - return new EsqlQueryResponse(columns, values, columnar); + return new EsqlQueryResponse(columns, values, profile, columnar); } private ColumnInfo randomColumnInfo() { @@ -99,6 +105,13 @@ private ColumnInfo randomColumnInfo() { return new ColumnInfo(randomAlphaOfLength(10), type.esType()); } + private EsqlQueryResponse.Profile randomProfile() { + if (randomBoolean()) { + return null; + } + return new EsqlQueryResponseProfileTests().createTestInstance(); + } + private Page randomPage(List columns) { return new Page(columns.stream().map(c -> { Block.Builder builder = LocalExecutionPlanner.toElementType(EsqlDataTypes.fromName(c.type())).newBlockBuilder(1, blockFactory); @@ -148,23 +161,34 @@ protected EsqlQueryResponse mutateInstance(EsqlQueryResponse instance) { allNull = false; } } - return switch (allNull ? between(0, 1) : between(0, 2)) { + return switch (allNull ? 
between(0, 2) : between(0, 3)) { case 0 -> { int mutCol = between(0, instance.columns().size() - 1); List cols = new ArrayList<>(instance.columns()); // keep the type the same so the values are still valid but change the name cols.set(mutCol, new ColumnInfo(cols.get(mutCol).name() + "mut", cols.get(mutCol).type())); - yield new EsqlQueryResponse(cols, deepCopyOfPages(instance), instance.columnar()); + yield new EsqlQueryResponse(cols, deepCopyOfPages(instance), instance.profile(), instance.columnar()); } - case 1 -> new EsqlQueryResponse(instance.columns(), deepCopyOfPages(instance), false == instance.columnar()); - case 2 -> { + case 1 -> new EsqlQueryResponse( + instance.columns(), + deepCopyOfPages(instance), + instance.profile(), + false == instance.columnar() + ); + case 2 -> new EsqlQueryResponse( + instance.columns(), + deepCopyOfPages(instance), + randomValueOtherThan(instance.profile(), this::randomProfile), + instance.columnar() + ); + case 3 -> { int noPages = instance.pages().size(); List differentPages = List.of(); do { differentPages.forEach(p -> Releasables.closeExpectNoException(p::releaseBlocks)); differentPages = randomList(noPages, noPages, () -> randomPage(instance.columns())); } while (differentPages.equals(instance.pages())); - yield new EsqlQueryResponse(instance.columns(), differentPages, instance.columnar()); + yield new EsqlQueryResponse(instance.columns(), differentPages, instance.profile(), instance.columnar()); } default -> throw new IllegalArgumentException(); }; @@ -194,7 +218,7 @@ protected EsqlQueryResponse doParseInstance(XContentParser parser) { } public void testChunkResponseSizeColumnar() { - try (EsqlQueryResponse resp = randomResponse(true)) { + try (EsqlQueryResponse resp = randomResponse(true, null)) { int columnCount = resp.pages().get(0).getBlockCount(); int bodySize = resp.pages().stream().mapToInt(p -> p.getPositionCount() * p.getBlockCount()).sum() + columnCount * 2; assertChunkCount(resp, r -> 5 + bodySize); @@ -202,7 
+226,7 @@ public void testChunkResponseSizeColumnar() { } public void testChunkResponseSizeRows() { - try (EsqlQueryResponse resp = randomResponse(false)) { + try (EsqlQueryResponse resp = randomResponse(false, null)) { int bodySize = resp.pages().stream().mapToInt(p -> p.getPositionCount()).sum(); assertChunkCount(resp, r -> 5 + bodySize); } @@ -226,10 +250,28 @@ private EsqlQueryResponse simple(boolean columnar) { return new EsqlQueryResponse( List.of(new ColumnInfo("foo", "integer")), List.of(new Page(new IntArrayVector(new int[] { 40, 80 }, 2).asBlock())), + null, columnar ); } + public void testProfileXContent() { + try ( + EsqlQueryResponse response = new EsqlQueryResponse( + List.of(new ColumnInfo("foo", "integer")), + List.of(new Page(new IntArrayVector(new int[] { 40, 80 }, 2).asBlock())), + new EsqlQueryResponse.Profile( + List.of(new DriverProfile(List.of(new DriverStatus.OperatorStatus("asdf", new AbstractPageMappingOperator.Status(10))))) + ), + false + ); + ) { + assertThat(Strings.toString(response), equalTo(""" + {"columns":[{"name":"foo","type":"integer"}],"values":[[40],[80]],"profile":{"drivers":[""" + """ + {"operators":[{"operator":"asdf","status":{"pages_processed":10}}]}]}}""")); + } + } + @Override protected void dispose(EsqlQueryResponse esqlQueryResponse) { esqlQueryResponse.close(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java index 95d8babcc5802..9430e984039fe 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatTests.java @@ -231,12 +231,12 @@ public void testPlainTextEmptyCursorWithColumns() { public void testPlainTextEmptyCursorWithoutColumns() { assertEquals( StringUtils.EMPTY, - getTextBodyContent(PLAIN_TEXT.format(req(), new 
EsqlQueryResponse(emptyList(), emptyList(), false))) + getTextBodyContent(PLAIN_TEXT.format(req(), new EsqlQueryResponse(emptyList(), emptyList(), null, false))) ); } private static EsqlQueryResponse emptyData() { - return new EsqlQueryResponse(singletonList(new ColumnInfo("name", "keyword")), emptyList(), false); + return new EsqlQueryResponse(singletonList(new ColumnInfo("name", "keyword")), emptyList(), null, false); } private static EsqlQueryResponse regularData() { @@ -259,7 +259,7 @@ private static EsqlQueryResponse regularData() { ) ); - return new EsqlQueryResponse(headers, values, false); + return new EsqlQueryResponse(headers, values, null, false); } private static EsqlQueryResponse escapedData() { @@ -277,7 +277,7 @@ private static EsqlQueryResponse escapedData() { ) ); - return new EsqlQueryResponse(headers, values, false); + return new EsqlQueryResponse(headers, values, null, false); } private static RestRequest req() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java index 558a92de70351..22e532341d30b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/formatter/TextFormatterTests.java @@ -61,6 +61,7 @@ public class TextFormatterTests extends ESTestCase { Block.constantNullBlock(2) ) ), + null, randomBoolean() ); @@ -123,6 +124,7 @@ public void testFormatWithoutHeader() { Block.constantNullBlock(2) ) ), + null, randomBoolean() ); @@ -161,6 +163,7 @@ public void testVeryLongPadding() { .build() ) ), + null, randomBoolean() ) ).format(false) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java index 0f6ed2d1ab3bb..b4c9d7a9baeca 100644 
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java @@ -74,7 +74,8 @@ public class EvalMapperTests extends ESTestCase { null, 10000000, 10000, - StringUtils.EMPTY + StringUtils.EMPTY, + false ); @ParametersFactory(argumentFormatting = "%1$s") diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index c8c8029f994cc..a01d82731bc94 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -142,7 +142,8 @@ private EsqlConfiguration config() { pragmas, EsqlPlugin.QUERY_RESULT_TRUNCATION_MAX_SIZE.getDefault(null), EsqlPlugin.QUERY_RESULT_TRUNCATION_DEFAULT_SIZE.getDefault(null), - StringUtils.EMPTY + StringUtils.EMPTY, + false ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlConfigurationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlConfigurationSerializationTests.java index aaa76c068f58a..9879f7c9ed23d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlConfigurationSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/EsqlConfigurationSerializationTests.java @@ -42,8 +42,19 @@ public static EsqlConfiguration randomConfiguration(String query) { var clusterName = randomAlphaOfLengthBetween(3, 10); var truncation = randomNonNegativeInt(); var defaultTruncation = randomNonNegativeInt(); + boolean profile = randomBoolean(); - return new EsqlConfiguration(zoneId, locale, username, clusterName, randomQueryPragmas(), truncation, 
defaultTruncation, query); + return new EsqlConfiguration( + zoneId, + locale, + username, + clusterName, + randomQueryPragmas(), + truncation, + defaultTruncation, + query, + profile + ); } @Override @@ -53,7 +64,7 @@ protected EsqlConfiguration createTestInstance() { @Override protected EsqlConfiguration mutateInstance(EsqlConfiguration in) throws IOException { - int ordinal = between(0, 7); + int ordinal = between(0, 8); return new EsqlConfiguration( ordinal == 0 ? randomValueOtherThan(in.zoneId(), () -> randomZone().normalized()) : in.zoneId(), ordinal == 1 ? randomValueOtherThan(in.locale(), () -> randomLocale(random())) : in.locale(), @@ -64,7 +75,8 @@ protected EsqlConfiguration mutateInstance(EsqlConfiguration in) throws IOExcept : in.pragmas(), ordinal == 5 ? in.resultTruncationMaxSize() + randomIntBetween(3, 10) : in.resultTruncationMaxSize(), ordinal == 6 ? in.resultTruncationDefaultSize() + randomIntBetween(3, 10) : in.resultTruncationDefaultSize(), - ordinal == 7 ? randomAlphaOfLength(100) : in.query() + ordinal == 7 ? randomAlphaOfLength(100) : in.query(), + ordinal == 8 ? 
in.profile() == false : in.profile() ); } } diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMigrateToDataTiersAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMigrateToDataTiersAction.java index ae0df89c9bb8f..8cc14a42eb5f3 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMigrateToDataTiersAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMigrateToDataTiersAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.Priority; @@ -46,6 +47,7 @@ public class TransportMigrateToDataTiersAction extends TransportMasterNodeAction private static final Logger logger = LogManager.getLogger(TransportMigrateToDataTiersAction.class); + private final RerouteService rerouteService; private final NamedXContentRegistry xContentRegistry; private final Client client; private final XPackLicenseState licenseState; @@ -54,6 +56,7 @@ public class TransportMigrateToDataTiersAction extends TransportMasterNodeAction public TransportMigrateToDataTiersAction( TransportService transportService, ClusterService clusterService, + RerouteService rerouteService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, @@ -72,6 +75,7 @@ public TransportMigrateToDataTiersAction( MigrateToDataTiersResponse::new, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.rerouteService = rerouteService; this.xContentRegistry = xContentRegistry; this.client = client; this.licenseState = licenseState; @@ -141,20 +145,19 @@ public void 
onFailure(Exception e) { @Override public void clusterStateProcessed(ClusterState oldState, ClusterState newState) { - clusterService.getRerouteService() - .reroute("cluster migrated to data tiers routing", Priority.NORMAL, new ActionListener() { - @Override - public void onResponse(Void ignored) {} - - @Override - public void onFailure(Exception e) { - logger.log( - MasterService.isPublishFailureException(e) ? Level.DEBUG : Level.WARN, - "unsuccessful reroute after migration to data tiers routing", - e - ); - } - }); + rerouteService.reroute("cluster migrated to data tiers routing", Priority.NORMAL, new ActionListener() { + @Override + public void onResponse(Void ignored) {} + + @Override + public void onFailure(Exception e) { + logger.log( + MasterService.isPublishFailureException(e) ? Level.DEBUG : Level.WARN, + "unsuccessful reroute after migration to data tiers routing", + e + ); + } + }); MigratedEntities entities = migratedEntities.get(); listener.onResponse( new MigrateToDataTiersResponse( diff --git a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java index e39463141c777..b37eb8f99f52c 100644 --- a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java +++ b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java @@ -299,6 +299,7 @@ public void testOperatorControllerFromJSONContent() throws IOException { ReservedClusterStateService controller = new ReservedClusterStateService( clusterService, + null, List.of(new ReservedClusterSettingsAction(clusterSettings)) ); @@ -371,6 +372,7 @@ public void testOperatorControllerFromJSONContent() throws IOException { controller = new ReservedClusterStateService( clusterService, + null, List.of( new ReservedClusterSettingsAction(clusterSettings), new 
ReservedLifecycleAction(xContentRegistry(), client, licenseState) @@ -393,6 +395,7 @@ public void testOperatorControllerWithPluginPackage() { ReservedClusterStateService controller = new ReservedClusterStateService( clusterService, + null, List.of(new ReservedClusterSettingsAction(clusterSettings)) ); @@ -430,6 +433,7 @@ public void testOperatorControllerWithPluginPackage() { controller = new ReservedClusterStateService( clusterService, + null, List.of( new ReservedClusterSettingsAction(clusterSettings), new ReservedLifecycleAction(xContentRegistry(), client, licenseState) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java new file mode 100644 index 0000000000000..1578e03608e82 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference; + +import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.ClassRule; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class InferenceBaseRestTest extends ESRestTestCase { + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.security.enabled", "true") + .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin") + .user("x_pack_rest_user", "x-pack-test-password") + .build(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + + static String mockServiceModelConfig() { + return """ + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + "temperature": 3 + } + } + """; + } + + protected Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException 
{ + String endpoint = Strings.format("_inference/%s/%s", taskType, modelId); + var request = new Request("PUT", endpoint); + request.setJsonEntity(modelConfig); + var response = client().performRequest(request); + assertOkOrCreated(response); + return entityAsMap(response); + } + + protected Map getModels(String modelId, TaskType taskType) throws IOException { + var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); + var request = new Request("GET", endpoint); + var response = client().performRequest(request); + assertOkOrCreated(response); + return entityAsMap(response); + } + + protected Map getAllModels() throws IOException { + var endpoint = Strings.format("_inference/_all"); + var request = new Request("GET", endpoint); + var response = client().performRequest(request); + assertOkOrCreated(response); + return entityAsMap(response); + } + + protected Map inferOnMockService(String modelId, TaskType taskType, List input) throws IOException { + var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); + var request = new Request("POST", endpoint); + + var bodyBuilder = new StringBuilder("{\"input\": ["); + for (var in : input) { + bodyBuilder.append('"').append(in).append('"').append(','); + } + // remove last comma + bodyBuilder.deleteCharAt(bodyBuilder.length() - 1); + bodyBuilder.append("]}"); + + request.setJsonEntity(bodyBuilder.toString()); + var response = client().performRequest(request); + assertOkOrCreated(response); + return entityAsMap(response); + } + + @SuppressWarnings("unchecked") + protected void assertNonEmptyInferenceResults(Map resultMap, int expectedNumberOfResults, TaskType taskType) { + if (taskType == TaskType.SPARSE_EMBEDDING) { + var results = (List>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString()); + assertThat(results, hasSize(expectedNumberOfResults)); + } else { + fail("test with task type [" + taskType + "] are not supported yet"); + } + } + + protected static void assertOkOrCreated(Response response) 
throws IOException { + int statusCode = response.getStatusLine().getStatusCode(); + // Once EntityUtils.toString(entity) is called the entity cannot be reused. + // Avoid that call with check here. + if (statusCode == 200 || statusCode == 201) { + return; + } + + String responseStr = EntityUtils.toString(response.getEntity()); + assertThat(responseStr, response.getStatusLine().getStatusCode(), anyOf(equalTo(200), equalTo(201))); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java new file mode 100644 index 0000000000000..61278fcae6d94 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; + +public class InferenceCrudIT extends InferenceBaseRestTest { + + @SuppressWarnings("unchecked") + public void testGet() throws IOException { + for (int i = 0; i < 5; i++) { + putModel("se_model_" + i, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + } + for (int i = 0; i < 4; i++) { + putModel("te_model_" + i, mockServiceModelConfig(), TaskType.TEXT_EMBEDDING); + } + + var getAllModels = (List>) getAllModels().get("models"); + assertThat(getAllModels, hasSize(9)); + + var getSparseModels = (List>) getModels("_all", TaskType.SPARSE_EMBEDDING).get("models"); + assertThat(getSparseModels, hasSize(5)); + for (var sparseModel : getSparseModels) { + assertEquals("sparse_embedding", sparseModel.get("task_type")); + } + + var getDenseModels = (List>) getModels("_all", TaskType.TEXT_EMBEDDING).get("models"); + assertThat(getDenseModels, hasSize(4)); + for (var denseModel : getDenseModels) { + assertEquals("text_embedding", denseModel.get("task_type")); + } + + var singleModel = (List>) getModels("se_model_1", TaskType.SPARSE_EMBEDDING).get("models"); + assertThat(singleModel, hasSize(1)); + assertEquals("se_model_1", singleModel.get(0).get("model_id")); + } + + public void testGetModelWithWrongTaskType() throws IOException { + putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING)); + assertThat( + e.getMessage(), + containsString("Requested task type [text_embedding] does not match the model's task type [sparse_embedding]") + ); + } + + @SuppressWarnings("unchecked") + public void 
testGetModelWithAnyTaskType() throws IOException { + String modelId = "sparse_embedding_model"; + putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + var singleModel = (List>) getModels(modelId, TaskType.ANY).get("models"); + System.out.println("MODEL" + singleModel); + assertEquals(modelId, singleModel.get(0).get("model_id")); + assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type")); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java index 5ed11958fc64e..f8abfd45a8566 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java @@ -7,67 +7,13 @@ package org.elasticsearch.xpack.inference; -import org.apache.http.util.EntityUtils; -import org.elasticsearch.client.Request; -import org.elasticsearch.client.Response; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.cluster.ElasticsearchCluster; -import org.elasticsearch.test.cluster.local.distribution.DistributionType; -import org.elasticsearch.test.rest.ESRestTestCase; -import org.junit.ClassRule; import java.io.IOException; import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.anyOf; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.emptyString; -import static org.hamcrest.Matchers.equalTo; -import static 
org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; - -public class MockInferenceServiceIT extends ESRestTestCase { - - @ClassRule - public static ElasticsearchCluster cluster = ElasticsearchCluster.local() - .distribution(DistributionType.DEFAULT) - .setting("xpack.license.self_generated.type", "trial") - .setting("xpack.security.enabled", "true") - .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin") - .user("x_pack_rest_user", "x-pack-test-password") - .build(); - - @Override - protected String getTestRestCluster() { - return cluster.getHttpAddresses(); - } - - @Override - protected Settings restClientSettings() { - String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); - return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); - } - - static String mockServiceModelConfig() { - return Strings.format(""" - { - "service": "test_service", - "service_settings": { - "model": "my_model", - "api_key": "abc64" - }, - "task_settings": { - "temperature": 3 - } - } - """); - } +public class MockInferenceServiceIT extends InferenceBaseRestTest { @SuppressWarnings("unchecked") public void testMockService() throws IOException { @@ -84,7 +30,7 @@ public void testMockService() throws IOException { // The response is randomly generated, the input can be anything var inference = inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10))); - assertNonEmptyInferenceResults(inference, TaskType.SPARSE_EMBEDDING); + assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); } @SuppressWarnings("unchecked") @@ -99,9 +45,7 @@ public void testMockServiceWithMultipleInputs() throws IOException { List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15)) ); - var results = (List>) inference.get("result"); - assertThat(results, hasSize(3)); - 
assertNonEmptyInferenceResults(inference, TaskType.SPARSE_EMBEDDING); + assertNonEmptyInferenceResults(inference, 3, TaskType.SPARSE_EMBEDDING); } @SuppressWarnings("unchecked") @@ -119,63 +63,4 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti assertNull(putServiceSettings.get("api_key")); assertNotNull(putServiceSettings.get("model")); } - - private Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { - String endpoint = Strings.format("_inference/%s/%s", taskType, modelId); - var request = new Request("PUT", endpoint); - request.setJsonEntity(modelConfig); - var reponse = client().performRequest(request); - assertOkWithErrorMessage(reponse); - return entityAsMap(reponse); - } - - public Map getModels(String modelId, TaskType taskType) throws IOException { - var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); - var request = new Request("GET", endpoint); - var reponse = client().performRequest(request); - assertOkWithErrorMessage(reponse); - return entityAsMap(reponse); - } - - private Map inferOnMockService(String modelId, TaskType taskType, List input) throws IOException { - var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); - var request = new Request("POST", endpoint); - - var bodyBuilder = new StringBuilder("{\"input\": ["); - for (var in : input) { - bodyBuilder.append('"').append(in).append('"').append(','); - } - // remove last comma - bodyBuilder.deleteCharAt(bodyBuilder.length() - 1); - bodyBuilder.append("]}"); - - System.out.println("body_request:" + bodyBuilder); - request.setJsonEntity(bodyBuilder.toString()); - var reponse = client().performRequest(request); - assertOkWithErrorMessage(reponse); - return entityAsMap(reponse); - } - - @SuppressWarnings("unchecked") - protected void assertNonEmptyInferenceResults(Map resultMap, TaskType taskType) { - if (taskType == TaskType.SPARSE_EMBEDDING) { - var results = (List) resultMap.get("result"); - 
assertThat(results, not(empty())); - for (String result : results) { - assertThat(result, is(not(emptyString()))); - } - } else { - fail("test with task type [" + taskType + "] are not supported yet"); - } - } - - protected static void assertOkWithErrorMessage(Response response) throws IOException { - int statusCode = response.getStatusLine().getStatusCode(); - if (statusCode == 200 || statusCode == 201) { - return; - } - - String responseStr = EntityUtils.toString(response.getEntity()); - assertThat(responseStr, response.getStatusLine().getStatusCode(), anyOf(equalTo(200), equalTo(201))); - } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/build.gradle b/x-pack/plugin/inference/qa/test-service-plugin/build.gradle index 9020589f74a0c..031c7519154b1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/build.gradle +++ b/x-pack/plugin/inference/qa/test-service-plugin/build.gradle @@ -6,6 +6,7 @@ esplugin { name 'inference-service-test' description 'A mock inference service' classname 'org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin' + extendedPlugins = ['x-pack-inference'] } dependencies { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java new file mode 100644 index 0000000000000..eee6f68c20ff7 --- /dev/null +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java @@ -0,0 +1,324 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.mock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class TestInferenceServiceExtension implements InferenceServiceExtension { + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + public static class TestInferenceService implements InferenceService { + private static final String NAME = "test_service"; + + public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} + + @Override + public String name() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @SuppressWarnings("unchecked") + private static Map getTaskSettingsMap(Map settings) { + Map taskSettingsMap; + // task settings 
 are optional + if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) { + taskSettingsMap = (Map) settings.remove(ModelConfigurations.TASK_SETTINGS); + } else { + taskSettingsMap = Map.of(); + } + + return taskSettingsMap; + } + + @Override + @SuppressWarnings("unchecked") + public TestServiceModel parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); + } + + @Override + @SuppressWarnings("unchecked") + public TestServiceModel parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + var secretSettingsMap = (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); + + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); + } + + @Override + @SuppressWarnings("unchecked") + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); + + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + + var taskSettingsMap = getTaskSettingsMap(config); + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+ + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null); + } + + @Override + public void infer( + Model model, + List input, + Map taskSettings, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case ANY -> listener.onResponse(makeResults(input)); + case SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + + private SparseEmbeddingResults makeResults(List input) { + var embeddings = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + var tokens = new ArrayList(); + for (int j = 0; j < 5; j++) { + tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j)); + } + embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false)); + } + return new SparseEmbeddingResults(embeddings); + } + + @Override + public void start(Model model, ActionListener listener) { + listener.onResponse(true); + } + + @Override + public void close() throws IOException {} + } + + public static class TestServiceModel extends Model { + + public TestServiceModel( + String modelId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); + } + + @Override + public TestServiceSettings getServiceSettings() { + return (TestServiceSettings) super.getServiceSettings(); + } + + @Override + public TestTaskSettings getTaskSettings() { + return (TestTaskSettings) super.getTaskSettings(); + } + + @Override + public TestSecretSettings getSecretSettings() { + return (TestSecretSettings) super.getSecretSettings(); + } + } + + public record TestServiceSettings(String 
model) implements ServiceSettings { + + static final String NAME = "test_service_settings"; + + public static TestServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = (String) map.remove("model"); + + if (model == null) { + validationException.addValidationError("missing model"); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new TestServiceSettings(model); + } + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("model", model); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + } + } + + public record TestTaskSettings(Integer temperature) implements TaskSettings { + + static final String NAME = "test_task_settings"; + + public static TestTaskSettings fromMap(Map map) { + Integer temperature = (Integer) map.remove("temperature"); + return new TestTaskSettings(temperature); + } + + public TestTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalVInt()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(temperature); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (temperature != null) { + builder.field("temperature", temperature); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() 
{ + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + } + + public record TestSecretSettings(String apiKey) implements SecretSettings { + + static final String NAME = "test_secret_settings"; + + public static TestSecretSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String apiKey = (String) map.remove("api_key"); + + if (apiKey == null) { + validationException.addValidationError("missing api_key"); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new TestSecretSettings(apiKey); + } + + public TestSecretSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(apiKey); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("api_key", apiKey); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java index 4d8cb18e541ff..0345d7b6e5926 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java @@ -7,395 +7,34 @@ package org.elasticsearch.xpack.inference.mock; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.Set; -public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin { +public class TestInferenceServicePlugin extends Plugin { @Override - public List getInferenceServiceFactories() { - return List.of(TestInferenceService::new, TestInferenceServiceClusterService::new); - } - - @Override - public List getInferenceServiceNamedWriteables() { + public List getNamedWriteables() { return List.of( - new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new), - new 
NamedWriteableRegistry.Entry(TaskSettings.class, TestTaskSettings.NAME, TestTaskSettings::new), - new NamedWriteableRegistry.Entry(SecretSettings.class, TestSecretSettings.NAME, TestSecretSettings::new) + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + TestInferenceServiceExtension.TestServiceSettings.NAME, + TestInferenceServiceExtension.TestServiceSettings::new + ), + new NamedWriteableRegistry.Entry( + TaskSettings.class, + TestInferenceServiceExtension.TestTaskSettings.NAME, + TestInferenceServiceExtension.TestTaskSettings::new + ), + new NamedWriteableRegistry.Entry( + SecretSettings.class, + TestInferenceServiceExtension.TestSecretSettings.NAME, + TestInferenceServiceExtension.TestSecretSettings::new + ) ); } - - public static class TestInferenceService extends TestInferenceServiceBase { - private static final String NAME = "test_service"; - - public TestInferenceService(InferenceServiceFactoryContext context) { - super(context); - } - - @Override - public String name() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - } - - public static class TestInferenceServiceClusterService extends TestInferenceServiceBase { - private static final String NAME = "test_service_in_cluster_service"; - - public TestInferenceServiceClusterService(InferenceServiceFactoryContext context) { - super(context); - } - - @Override - public boolean isInClusterService() { - return true; - } - - @Override - public String name() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - } - - public abstract static class TestInferenceServiceBase implements InferenceService { - - @SuppressWarnings("unchecked") - private static Map getTaskSettingsMap(Map settings) { - Map 
taskSettingsMap; - // task settings are optional - if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = (Map) settings.remove(ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - return taskSettingsMap; - } - - public TestInferenceServiceBase(InferenceServicePlugin.InferenceServiceFactoryContext context) { - - } - - @Override - @SuppressWarnings("unchecked") - public TestServiceModel parseRequestConfig( - String modelId, - TaskType taskType, - Map config, - Set platfromArchitectures - ) { - var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); - var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap); - - var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - - return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); - } - - @Override - @SuppressWarnings("unchecked") - public TestServiceModel parsePersistedConfig( - String modelId, - TaskType taskType, - Map config, - Map secrets - ) { - var serviceSettingsMap = (Map) config.remove(ModelConfigurations.SERVICE_SETTINGS); - var secretSettingsMap = (Map) secrets.remove(ModelSecrets.SECRET_SETTINGS); - - var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); - var secretSettings = TestSecretSettings.fromMap(secretSettingsMap); - - var taskSettingsMap = getTaskSettingsMap(config); - var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - - return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); - } - - @Override - public void infer( - Model model, - List input, - Map taskSettings, - ActionListener listener - ) { - switch (model.getConfigurations().getTaskType()) { - case SPARSE_EMBEDDING -> { - var strings = new ArrayList(); - for (int i = 0; i < input.size(); i++) { - 
strings.add(Integer.toString(i)); - } - - listener.onResponse(new TestResults(strings)); - } - default -> listener.onFailure( - new ElasticsearchStatusException( - TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), - RestStatus.BAD_REQUEST - ) - ); - } - - } - - @Override - public void start(Model model, ActionListener listener) { - listener.onResponse(true); - } - - @Override - public void close() throws IOException {} - } - - public static class TestServiceModel extends Model { - - public TestServiceModel( - String modelId, - TaskType taskType, - String service, - TestServiceSettings serviceSettings, - TestTaskSettings taskSettings, - TestSecretSettings secretSettings - ) { - super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); - } - - @Override - public TestServiceSettings getServiceSettings() { - return (TestServiceSettings) super.getServiceSettings(); - } - - @Override - public TestTaskSettings getTaskSettings() { - return (TestTaskSettings) super.getTaskSettings(); - } - - @Override - public TestSecretSettings getSecretSettings() { - return (TestSecretSettings) super.getSecretSettings(); - } - } - - public record TestServiceSettings(String model) implements ServiceSettings { - - private static final String NAME = "test_service_settings"; - - public static TestServiceSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String model = (String) map.remove("model"); - - if (model == null) { - validationException.addValidationError("missing model"); - } - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new TestServiceSettings(model); - } - - public TestServiceSettings(StreamInput in) throws IOException { - this(in.readString()); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - 
builder.startObject(); - builder.field("model", model); - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(model); - } - } - - public record TestTaskSettings(Integer temperature) implements TaskSettings { - - private static final String NAME = "test_task_settings"; - - public static TestTaskSettings fromMap(Map map) { - Integer temperature = (Integer) map.remove("temperature"); - return new TestTaskSettings(temperature); - } - - public TestTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalVInt()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalVInt(temperature); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (temperature != null) { - builder.field("temperature", temperature); - } - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - } - - public record TestSecretSettings(String apiKey) implements SecretSettings { - - private static final String NAME = "test_secret_settings"; - - public static TestSecretSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String apiKey = (String) map.remove("api_key"); - - if (apiKey == null) { - validationException.addValidationError("missing api_key"); - } - - if (validationException.validationErrors().isEmpty() == false) { - throw 
validationException; - } - - return new TestSecretSettings(apiKey); - } - - public TestSecretSettings(StreamInput in) throws IOException { - this(in.readString()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(apiKey); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("api_key", apiKey); - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests - } - } - - private static class TestResults implements InferenceServiceResults, InferenceResults { - - private static final String RESULTS_FIELD = "result"; - private List result; - - TestResults(List result) { - this.result = result; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RESULTS_FIELD, result); - return builder; - } - - @Override - public String getWriteableName() { - return "test_result"; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeStringCollection(result); - } - - @Override - public String getResultsField() { - return RESULTS_FIELD; - } - - @Override - public List transformToLegacyFormat() { - return List.of(this); - } - - @Override - public Map asMap() { - return Map.of("result", result); - } - - @Override - public Map asMap(String outputField) { - return Map.of(outputField, result); - } - - @Override - public Object predictedValue() { - return result; - } - } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension new file mode 100644 index 0000000000000..019a6dad7be85 --- /dev/null +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension @@ -0,0 +1 @@ +org.elasticsearch.xpack.inference.mock.TestInferenceServiceExtension diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 520a4cc5c0526..50647ca328b23 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -8,20 +8,23 @@ package org.elasticsearch.xpack.inference.integration; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xcontent.XContentBuilder; import 
org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; @@ -31,13 +34,21 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.stream.Collectors; +import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -59,7 +70,7 @@ protected Collection> getPlugins() { public void testStoreModel() throws Exception { String modelId = "test-store-model"; - Model model = buildModelConfig(modelId, ElserMlNodeService.NAME, TaskType.SPARSE_EMBEDDING); + Model model = buildElserModelConfig(modelId, TaskType.SPARSE_EMBEDDING); AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -90,7 +101,7 @@ public void testStoreModelWithUnknownFields() throws Exception { public void testGetModel() throws Exception { String modelId = "test-get-model"; - Model model = buildModelConfig(modelId, ElserMlNodeService.NAME, TaskType.SPARSE_EMBEDDING); + Model model = buildElserModelConfig(modelId, TaskType.SPARSE_EMBEDDING); AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -98,27 +109,26 @@ public void 
testGetModel() throws Exception { assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getUnparsedModelMap(modelId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(modelId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); - UnparsedModel unparsedModel = UnparsedModel.unparsedModelFromMap(modelHolder.get().config(), modelHolder.get().secrets()); - assertEquals(model.getConfigurations().getService(), unparsedModel.service()); + assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElserMlNodeService(new InferenceServicePlugin.InferenceServiceFactoryContext(mock(Client.class))); - ElserMlNodeModel roundTripModel = elserService.parsePersistedConfig( - unparsedModel.modelId(), - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() + var elserService = new ElserMlNodeService(new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class))); + ElserMlNodeModel roundTripModel = elserService.parsePersistedConfigWithSecrets( + modelHolder.get().modelId(), + modelHolder.get().taskType(), + modelHolder.get().settings(), + modelHolder.get().secrets() ); assertEquals(model, roundTripModel); } public void testStoreModelFailsWhenModelExists() throws Exception { String modelId = "test-put-trained-model-config-exists"; - Model model = buildModelConfig(modelId, ElserMlNodeService.NAME, TaskType.SPARSE_EMBEDDING); + Model model = buildElserModelConfig(modelId, TaskType.SPARSE_EMBEDDING); AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -140,7 +150,7 @@ public void testStoreModelFailsWhenModelExists() throws 
Exception { public void testDeleteModel() throws Exception { // put models for (var id : new String[] { "model1", "model2", "model3" }) { - Model model = buildModelConfig(id, ElserMlNodeService.NAME, TaskType.SPARSE_EMBEDDING); + Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING); AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); @@ -155,19 +165,115 @@ public void testDeleteModel() throws Exception { // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getUnparsedModelMap("model1", listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); assertThat(exceptionHolder.get().getMessage(), containsString("Model not found [model1]")); } - private Model buildModelConfig(String modelId, String service, TaskType taskType) { - return switch (service) { - case ElserMlNodeService.NAME -> ElserMlNodeServiceTests.randomModelConfig(modelId, taskType); - default -> throw new IllegalArgumentException("unknown service " + service); - }; + public void testGetModelsByTaskType() throws InterruptedException { + var service = "foo"; + var sparseAndTextEmbeddingModels = new ArrayList(); + sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.SPARSE_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.SPARSE_EMBEDDING, service)); + 
sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.TEXT_EMBEDDING, service)); + sparseAndTextEmbeddingModels.add(createModel(randomAlphaOfLength(5), TaskType.TEXT_EMBEDDING, service)); + + for (var model : sparseAndTextEmbeddingModels) { + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + } + + AtomicReference exceptionHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(3)); + var sparseIds = sparseAndTextEmbeddingModels.stream() + .filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING) + .map(Model::getModelId) + .collect(Collectors.toSet()); + modelHolder.get().forEach(m -> { + assertTrue(sparseIds.contains(m.modelId())); + assertThat(m.secrets().keySet(), empty()); + }); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(2)); + var denseIds = sparseAndTextEmbeddingModels.stream() + .filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING) + .map(Model::getModelId) + .collect(Collectors.toSet()); + modelHolder.get().forEach(m -> { + assertTrue(denseIds.contains(m.modelId())); + assertThat(m.secrets().keySet(), empty()); + }); + } + + public void testGetAllModels() throws InterruptedException { + var service = "foo"; + var createdModels = new ArrayList(); + int modelCount = randomIntBetween(30, 100); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + for (int i = 0; i < 
modelCount; i++) { + var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); + createdModels.add(model); + + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + } + + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(modelCount)); + var getAllModels = modelHolder.get(); + + // sort in the same order as the returned models + createdModels.sort(Comparator.comparing(Model::getModelId)); + for (int i = 0; i < modelCount; i++) { + assertEquals(createdModels.get(i).getModelId(), getAllModels.get(i).modelId()); + assertEquals(createdModels.get(i).getTaskType(), getAllModels.get(i).taskType()); + assertEquals(createdModels.get(i).getConfigurations().getService(), getAllModels.get(i).service()); + assertThat(getAllModels.get(i).secrets().keySet(), empty()); + } + } + + @SuppressWarnings("unchecked") + public void testGetModelWithSecrets() throws InterruptedException { + var service = "foo"; + var modelId = "model-with-secrets"; + var secret = "abc"; + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var modelWithSecrets = createModelWithSecrets(modelId, randomFrom(TaskType.values()), service, secret); + blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(modelId, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); + var secretSettings = (Map) 
modelHolder.get().secrets().get("secret_settings"); + assertThat(secretSettings.get("secret"), equalTo(secret)); + + // get model without secrets + blockingCall(listener -> modelRegistry.getModel(modelId, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get().secrets().keySet(), empty()); + } + + private Model buildElserModelConfig(String modelId, TaskType taskType) { + return ElserMlNodeServiceTests.randomModelConfig(modelId, taskType); } protected void blockingCall(Consumer> function, AtomicReference response, AtomicReference error) @@ -197,6 +303,112 @@ private static Model buildModelWithUnknownField(String modelId) { ); } + public static Model createModel(String modelId, TaskType taskType, String service) { + return new Model(new ModelConfigurations(modelId, taskType, service, new TestModelOfAnyKind.TestModelServiceSettings())); + } + + public static Model createModelWithSecrets(String modelId, TaskType taskType, String service, String secret) { + return new Model( + new ModelConfigurations(modelId, taskType, service, new TestModelOfAnyKind.TestModelServiceSettings()), + new ModelSecrets(new TestModelOfAnyKind.TestSecretSettings(secret)) + ); + } + + private static class TestModelOfAnyKind extends ModelConfigurations { + + record TestModelServiceSettings() implements ServiceSettings { + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return "test_service_settings"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + } + + record TestTaskSettings() implements TaskSettings { + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + 
builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return "test_task_settings"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + } + + record TestSecretSettings(String key) implements SecretSettings { + @Override + public String getWriteableName() { + return "test_secrets"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(key); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("secret", key); + builder.endObject(); + return builder; + } + } + + TestModelOfAnyKind(String modelId, TaskType taskType, String service) { + super(modelId, taskType, service, new TestModelServiceSettings(), new TestTaskSettings()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("unknown_field", "foo"); + builder.field(MODEL_ID, getModelId()); + builder.field(TaskType.NAME, getTaskType().toString()); + builder.field(SERVICE, getService()); + builder.field(SERVICE_SETTINGS, getServiceSettings()); + builder.field(TASK_SETTINGS, getTaskSettings()); + builder.endObject(); + return builder; + } + } + private static class ModelWithUnknownField extends ModelConfigurations { ModelWithUnknownField( diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 801b0a1cd755c..87f623bdfe5cc 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -21,7 +21,9 @@ exports 
org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; exports org.elasticsearch.xpack.inference.rest; - exports org.elasticsearch.xpack.inference.results; exports org.elasticsearch.xpack.inference.services; + exports org.elasticsearch.xpack.inference.services.elser; + exports org.elasticsearch.xpack.inference.services.huggingface.elser; + exports org.elasticsearch.xpack.inference.services.openai; exports org.elasticsearch.xpack.inference; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 0ba7ca1d49150..092b1200fb80a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -14,9 +14,9 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xpack.inference.results.LegacyTextEmbeddingResults; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserSecretSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 476f19a286d53..7e7f2c9e05680 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,8 +21,10 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.indices.SystemIndexDescriptor; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.plugins.ActionPlugin; -import org.elasticsearch.plugins.InferenceServicePlugin; +import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; @@ -30,10 +32,10 @@ import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xpack.core.ClientHelper; -import org.elasticsearch.xpack.inference.action.DeleteInferenceModelAction; -import org.elasticsearch.xpack.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportInferenceAction; @@ -53,13 +55,14 @@ import 
org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, InferenceServicePlugin, SystemIndexPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin { public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -69,6 +72,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, InferenceSe private final SetOnce httpFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); + private final SetOnce inferenceServiceRegistry = new SetOnce<>(); + private List inferenceServiceExtensions; + public InferencePlugin(Settings settings) { this.settings = settings; } @@ -117,7 +123,39 @@ public Collection createComponents(PluginServices services) { httpFactory.set(httpRequestSenderFactory); ModelRegistry modelRegistry = new ModelRegistry(services.client()); - return List.of(modelRegistry); + + if (inferenceServiceExtensions == null) { + inferenceServiceExtensions = new ArrayList<>(); + } + var inferenceServices = new ArrayList<>(inferenceServiceExtensions); + inferenceServices.add(this::getInferenceServiceFactories); + + var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); + var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); + registry.init(services.client()); + inferenceServiceRegistry.set(registry); + + return List.of(modelRegistry, registry); + } + + @Override + public void loadExtensions(ExtensionLoader loader) { + inferenceServiceExtensions = 
loader.loadExtensions(InferenceServiceExtension.class); + } + + public List getInferenceServiceFactories() { + return List.of( + ElserMlNodeService::new, + context -> new HuggingFaceElserService(httpFactory, serviceComponents), + context -> new OpenAiService(httpFactory, serviceComponents) + ); + } + + @Override + public List getNamedWriteables() { + var entries = new ArrayList(); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return entries; } @Override @@ -182,20 +220,6 @@ public String getFeatureDescription() { return "Inference plugin for managing inference services and inference"; } - @Override - public List getInferenceServiceFactories() { - return List.of( - ElserMlNodeService::new, - context -> new HuggingFaceElserService(httpFactory, serviceComponents), - context -> new OpenAiService(httpFactory, serviceComponents) - ); - } - - @Override - public List getInferenceServiceNamedWriteables() { - return InferenceNamedWriteablesProvider.getNamedWriteables(); - } - @Override public void close() { var serviceComponentsRef = serviceComponents.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index 4305ff5a7b631..88a364d1de8fe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -21,6 +21,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction 
extends AcknowledgedTransportMasterNodeAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 90fe9667c33aa..52fc115d4a4a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -9,18 +9,26 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.inference.UnparsedModel; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Executor; public class TransportGetInferenceModelAction extends HandledTransportAction< GetInferenceModelAction.Request, @@ -28,11 +36,13 @@ public class TransportGetInferenceModelAction extends HandledTransportAction< private final ModelRegistry modelRegistry; private final 
InferenceServiceRegistry serviceRegistry; + private final Executor executor; @Inject public TransportGetInferenceModelAction( TransportService transportService, ActionFilters actionFilters, + ThreadPool threadPool, ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry ) { @@ -45,6 +55,7 @@ public TransportGetInferenceModelAction( ); this.modelRegistry = modelRegistry; this.serviceRegistry = serviceRegistry; + this.executor = threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME); } @Override @@ -53,8 +64,19 @@ protected void doExecute( GetInferenceModelAction.Request request, ActionListener listener ) { - modelRegistry.getUnparsedModelMap(request.getModelId(), ActionListener.wrap(modelConfigMap -> { - var unparsedModel = UnparsedModel.unparsedModelFromMap(modelConfigMap.config(), modelConfigMap.secrets()); + boolean modelIdIsWildCard = Strings.isAllOrWildcard(request.getModelId()); + + if (request.getTaskType() == TaskType.ANY && modelIdIsWildCard) { + getAllModels(listener); + } else if (modelIdIsWildCard) { + getModelsByTaskType(request.getTaskType(), listener); + } else { + getSingleModel(request.getModelId(), request.getTaskType(), listener); + } + } + + private void getSingleModel(String modelId, TaskType requestedTaskType, ActionListener listener) { + modelRegistry.getModel(modelId, ActionListener.wrap(unparsedModel -> { var service = serviceRegistry.getService(unparsedModel.service()); if (service.isEmpty()) { listener.onFailure( @@ -67,9 +89,56 @@ protected void doExecute( ); return; } - var model = service.get() - .parsePersistedConfig(unparsedModel.modelId(), unparsedModel.taskType(), unparsedModel.settings(), unparsedModel.secrets()); + + if (requestedTaskType.isAnyOrSame(unparsedModel.taskType()) == false) { + listener.onFailure( + new ElasticsearchStatusException( + "Requested task type [{}] does not match the model's task type [{}]", + RestStatus.BAD_REQUEST, + requestedTaskType, + unparsedModel.taskType() + ) + ); + return; + } 
+ + var model = service.get().parsePersistedConfig(unparsedModel.modelId(), unparsedModel.taskType(), unparsedModel.settings()); listener.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations()))); }, listener::onFailure)); } + + private void getAllModels(ActionListener listener) { + modelRegistry.getAllModels( + ActionListener.wrap(models -> executor.execute(ActionRunnable.supply(listener, () -> parseModels(models))), listener::onFailure) + ); + } + + private void getModelsByTaskType(TaskType taskType, ActionListener listener) { + modelRegistry.getModelsByTaskType( + taskType, + ActionListener.wrap(models -> executor.execute(ActionRunnable.supply(listener, () -> parseModels(models))), listener::onFailure) + ); + } + + private GetInferenceModelAction.Response parseModels(List unparsedModels) { + var parsedModels = new ArrayList(); + + for (var unparsedModel : unparsedModels) { + var service = serviceRegistry.getService(unparsedModel.service()); + if (service.isEmpty()) { + throw new ElasticsearchStatusException( + "Unknown service [{}] for model [{}]. 
", + RestStatus.INTERNAL_SERVER_ERROR, + unparsedModel.service(), + unparsedModel.modelId() + ); + } + parsedModels.add( + service.get() + .parsePersistedConfig(unparsedModel.modelId(), unparsedModel.taskType(), unparsedModel.settings()) + .getConfigurations() + ); + } + return new GetInferenceModelAction.Response(parsedModels); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 7718739420cf1..7fb86763ad534 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -19,7 +19,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.inference.UnparsedModel; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { @@ -42,8 +42,7 @@ public TransportInferenceAction( @Override protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - ActionListener getModelListener = ActionListener.wrap(modelConfigMap -> { - var unparsedModel = UnparsedModel.unparsedModelFromMap(modelConfigMap.config(), modelConfigMap.secrets()); + ActionListener getModelListener = ActionListener.wrap(unparsedModel -> { var service = serviceRegistry.getService(unparsedModel.service()); if (service.isEmpty()) { listener.onFailure( @@ -57,7 +56,8 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe return; } - if (request.getTaskType() != unparsedModel.taskType()) { + if 
(request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { + // not the wildcard task type and not the model task type listener.onFailure( new ElasticsearchStatusException( "Incompatible task_type, the requested type [{}] does not match the model type [{}]", @@ -70,11 +70,16 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe } var model = service.get() - .parsePersistedConfig(unparsedModel.modelId(), unparsedModel.taskType(), unparsedModel.settings(), unparsedModel.secrets()); + .parsePersistedConfigWithSecrets( + unparsedModel.modelId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); inferOnService(model, request, service.get(), listener); }, listener::onFailure); - modelRegistry.getUnparsedModelMap(request.getModelId(), getModelListener); + modelRegistry.getModelWithSecrets(request.getModelId(), getModelListener); } private void inferOnService( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 569d4e023928b..f6bb90d701a4a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -34,6 +34,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java index 566ca9ff1351f..fab22dce889a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java @@ -13,8 +13,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index 60b568678987d..c301ab2194415 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -13,8 +13,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import 
org.elasticsearch.xpack.inference.results.TextEmbeddingResults; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 2937d4616571a..05c664f7ceeea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -30,8 +30,11 @@ import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -39,9 +42,12 @@ import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; +import org.elasticsearch.xpack.inference.services.MapParsingUtils; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; @@ -51,14 +57,47 @@ public class ModelRegistry { public record ModelConfigMap(Map config, Map secrets) {} + /** + * Semi parsed model where model id, task type and service + * are known but the settings are not parsed. 
+ */ + public record UnparsedModel( + String modelId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) { + + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String modelId = MapParsingUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = MapParsingUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = MapParsingUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } + } + + private static final String TASK_TYPE_FIELD = "task_type"; + private static final String MODEL_ID_FIELD = "model_id"; private static final Logger logger = LogManager.getLogger(ModelRegistry.class); + private final OriginSettingClient client; public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } - public void getUnparsedModelMap(String modelId, ActionListener listener) { + /** + * Get a model with its secret settings + * @param modelId Model to get + * @param listener Model listener + */ + public void getModelWithSecrets(String modelId, ActionListener listener) { ActionListener searchListener = ActionListener.wrap(searchResponse -> { // There should be a hit for the configurations and secrets if (searchResponse.getHits().getHits().length == 0) { @@ -67,7 +106,7 @@ public void getUnparsedModelMap(String modelId, ActionListener l } var hits = searchResponse.getHits().getHits(); - listener.onResponse(createModelConfigMap(hits, modelId)); + listener.onResponse(UnparsedModel.unparsedModelFromMap(createModelConfigMap(hits, modelId))); }, listener::onFailure); @@ 
-80,6 +119,111 @@ public void getUnparsedModelMap(String modelId, ActionListener l client.search(modelSearch, searchListener); } + /** + * Get a model. + * Secret settings are not included + * @param modelId Model to get + * @param listener Model listener + */ + public void getModel(String modelId, ActionListener listener) { + ActionListener searchListener = ActionListener.wrap(searchResponse -> { + // There should be a hit for the configurations and secrets + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new ResourceNotFoundException("Model not found [{}]", modelId)); + return; + } + + var hits = searchResponse.getHits().getHits(); + var modelConfigs = parseHitsAsModels(hits).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + assert modelConfigs.size() == 1; + listener.onResponse(modelConfigs.get(0)); + + }, listener::onFailure); + + QueryBuilder queryBuilder = documentIdQuery(modelId); + SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(queryBuilder) + .setSize(1) + .setTrackTotalHits(false) + .request(); + + client.search(modelSearch, searchListener); + } + + /** + * Get all models of a particular task type. 
+ * Secret settings are not included + * @param taskType The task type + * @param listener Models listener + */ + public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { + ActionListener searchListener = ActionListener.wrap(searchResponse -> { + // Not an error if no models of this task_type + if (searchResponse.getHits().getHits().length == 0) { + listener.onResponse(List.of()); + return; + } + + var hits = searchResponse.getHits().getHits(); + var modelConfigs = parseHitsAsModels(hits).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + listener.onResponse(modelConfigs); + + }, listener::onFailure); + + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())); + + SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(queryBuilder) + .setSize(10_000) + .setTrackTotalHits(false) + .addSort(MODEL_ID_FIELD, SortOrder.ASC) + .request(); + + client.search(modelSearch, searchListener); + } + + /** + * Get all models. + * Secret settings are not included + * @param listener Models listener + */ + public void getAllModels(ActionListener> listener) { + ActionListener searchListener = ActionListener.wrap(searchResponse -> { + // Not an error if no models of this task_type + if (searchResponse.getHits().getHits().length == 0) { + listener.onResponse(List.of()); + return; + } + + var hits = searchResponse.getHits().getHits(); + var modelConfigs = parseHitsAsModels(hits).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + listener.onResponse(modelConfigs); + + }, listener::onFailure); + + // In theory the index should only contain model config documents + // and a match all query would be sufficient. 
But just in case the + // index has been polluted return only docs with a task_type field + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD)); + + SearchRequest modelSearch = client.prepareSearch(InferenceIndex.INDEX_PATTERN) + .setQuery(queryBuilder) + .setSize(10_000) + .setTrackTotalHits(false) + .addSort(MODEL_ID_FIELD, SortOrder.ASC) + .request(); + + client.search(modelSearch, searchListener); + } + + private List parseHitsAsModels(SearchHit[] hits) { + var modelConfigs = new ArrayList(); + for (var hit : hits) { + modelConfigs.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of())); + } + return modelConfigs; + } + private ModelConfigMap createModelConfigMap(SearchHit[] hits, String modelId) { Map mappedHits = Arrays.stream(hits).collect(Collectors.toMap(hit -> { if (hit.getIndex().startsWith(InferenceIndex.INDEX_NAME)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteInferenceModelAction.java index 74050d4b32e89..184b310a9f829 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestDeleteInferenceModelAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.inference.action.DeleteInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java index f57c800bd5bdc..ce291bcf006ae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestGetInferenceModelAction.java @@ -8,12 +8,12 @@ package org.elasticsearch.xpack.inference.rest; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.GET; @@ -26,13 +26,21 @@ public String getName() { @Override public List routes() { - return List.of(new Route(GET, "_inference/{task_type}/{model_id}")); + return List.of(new Route(GET, "_inference/{task_type}/{model_id}"), new Route(GET, "_inference/_all")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String taskType = restRequest.param("task_type"); - String modelId = restRequest.param("model_id"); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { + String modelId = null; + TaskType taskType = null; + if (restRequest.hasParam("task_type") == false && restRequest.hasParam("model_id") == false) { + // _all models request + modelId = "_all"; + taskType = TaskType.ANY; + } else { + taskType = TaskType.fromStringOrStatusException(restRequest.param("task_type")); + modelId = restRequest.param("model_id"); + } var request = new GetInferenceModelAction.Request(modelId, taskType); return channel -> 
client.execute(GetInferenceModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index 9d7a0d331b2b3..beecf75da38ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java index cf0eb857feba9..1199cf5688fcc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java index 20bea7f1347b3..45bbddc92f135 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java @@ -61,9 +61,12 @@ public static Map removeFromMapOrThrowIfNull(Map return value; } - @SuppressWarnings("unchecked") - public static Map removeFromMap(Map sourceMap, String fieldName) { - return (Map) sourceMap.remove(fieldName); + public static String removeStringOrThrowIfNull(Map sourceMap, String key) { + String value = removeAsType(sourceMap, key, String.class); + if (value == null) { + throw new ElasticsearchStatusException("Missing required field [{}]", RestStatus.BAD_REQUEST, key); + } + return value; } public static void throwIfNotEmptyMap(Map settingsMap, String serviceName) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index f1fab447ec757..048920356aca0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -14,17 +14,17 @@ import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.InferenceServicePlugin; import 
org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.List; @@ -52,7 +52,7 @@ public class ElserMlNodeService implements InferenceService { private final OriginSettingClient client; - public ElserMlNodeService(InferenceServicePlugin.InferenceServiceFactoryContext context) { + public ElserMlNodeService(InferenceServiceExtension.InferenceServiceFactoryContext context) { this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN); } @@ -100,12 +100,17 @@ public ElserMlNodeModel parseRequestConfig( } @Override - public ElserMlNodeModel parsePersistedConfig( + public ElserMlNodeModel parsePersistedConfigWithSecrets( String modelId, TaskType taskType, Map config, Map secrets ) { + return parsePersistedConfig(modelId, taskType, config); + } + + @Override + public ElserMlNodeModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); var serviceSettingsBuilder = ElserMlNodeServiceSettings.fromMap(serviceSettingsMap); @@ -160,7 +165,7 @@ public void start(Model model, ActionListener listener) { public void infer(Model model, List input, Map taskSettings, ActionListener listener) { // No task settings to override with requestTaskSettings - if (model.getConfigurations().getTaskType() != TaskType.SPARSE_EMBEDDING) { + if (TaskType.SPARSE_EMBEDDING.isAnyOrSame(model.getConfigurations().getTaskType()) == false) { listener.onFailure( new ElasticsearchStatusException( 
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 3aaa122e93fe9..8c978112c4ec3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -72,7 +72,7 @@ public HuggingFaceElserModel parseRequestConfig( } @Override - public HuggingFaceElserModel parsePersistedConfig( + public HuggingFaceElserModel parsePersistedConfigWithSecrets( String modelId, TaskType taskType, Map config, @@ -87,6 +87,14 @@ public HuggingFaceElserModel parsePersistedConfig( return new HuggingFaceElserModel(modelId, taskType, NAME, serviceSettings, secretSettings); } + @Override + public HuggingFaceElserModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + HuggingFaceElserServiceSettings serviceSettings = HuggingFaceElserServiceSettings.fromMap(serviceSettingsMap); + + return new HuggingFaceElserModel(modelId, taskType, NAME, serviceSettings, null); + } + @Override public void infer(Model model, List input, Map taskSettings, ActionListener listener) { if (model.getConfigurations().getTaskType() != TaskType.SPARSE_EMBEDDING) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 1d2d123432ab8..0a7ae147d13d1 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -95,7 +95,12 @@ private OpenAiModel createModel( } @Override - public OpenAiModel parsePersistedConfig(String modelId, TaskType taskType, Map config, Map secrets) { + public OpenAiModel parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); @@ -118,6 +123,27 @@ public OpenAiModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + OpenAiModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + format("Failed to parse stored model [%s] for [%s] service, please delete and add the service again", modelId, NAME) + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + return model; + } + @Override public void infer(Model model, List input, Map taskSettings, ActionListener listener) { init(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java index 0b30dc9021038..091c11a480c0d 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java @@ -10,11 +10,12 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; public class GetInferenceModelRequestTests extends AbstractWireSerializingTestCase { public static GetInferenceModelAction.Request randomTestInstance() { - return new GetInferenceModelAction.Request(randomAlphaOfLength(8), randomFrom(TaskType.values()).toString()); + return new GetInferenceModelAction.Request(randomAlphaOfLength(8), randomFrom(TaskType.values())); } @Override @@ -30,10 +31,10 @@ protected GetInferenceModelAction.Request createTestInstance() { @Override protected GetInferenceModelAction.Request mutateInstance(GetInferenceModelAction.Request instance) { return switch (randomIntBetween(0, 1)) { - case 0 -> new GetInferenceModelAction.Request(instance.getModelId() + "foo", instance.getTaskType().toString()); + case 0 -> new GetInferenceModelAction.Request(instance.getModelId() + "foo", instance.getTaskType()); case 1 -> { var nextTaskType = TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length]; - yield new GetInferenceModelAction.Request(instance.getModelId(), nextTaskType.toString()); + yield new GetInferenceModelAction.Request(instance.getModelId(), nextTaskType); } default -> throw new UnsupportedOperationException(); }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelResponseTests.java index 472e4123c52e6..72f6f43126f7c 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelResponseTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index d263cf8c776ea..aa540694ba564 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 515b6c268d0af..759411cec1212 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ 
-10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; @@ -25,7 +26,7 @@ import static org.elasticsearch.TransportVersions.INFERENCE_SERVICE_RESULTS_ADDED; import static org.elasticsearch.TransportVersions.ML_INFERENCE_OPENAI_ADDED; import static org.elasticsearch.TransportVersions.ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED; -import static org.elasticsearch.xpack.inference.action.InferenceAction.Response.transformToServiceResults; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Response.transformToServiceResults; public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index 9aefea9a942db..bdbca6426b601 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { @Override diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelResponseTests.java index 0a2ad4699cca8..89bd0247a9ccf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelResponseTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index ce94bfceed4fb..606e0cc83f451 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -12,8 +12,8 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentEOFException; import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResults; import 
org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index a3ec162b05ec8..56d8171640b53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -10,8 +10,8 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.results.TextEmbeddingResults; import java.io.IOException; import java.nio.charset.StandardCharsets; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 3b3134fe3d92e..b7d491bf54ddc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -18,8 +18,11 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.TimeValue; import 
org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESTestCase; @@ -29,12 +32,14 @@ import org.junit.After; import org.junit.Before; +import java.nio.ByteBuffer; import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -62,8 +67,8 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); - registry.getUnparsedModelMap("1", listener); + var listener = new PlainActionFuture(); + registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.getMessage(), is("Model not found [1]")); @@ -76,8 +81,8 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); - registry.getUnparsedModelMap("1", listener); + var listener = new PlainActionFuture(); + registry.getModelWithSecrets("1", listener); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -93,8 +98,8 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); - registry.getUnparsedModelMap("1", listener); + var listener = new PlainActionFuture(); + 
registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -110,8 +115,8 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); - registry.getUnparsedModelMap("1", listener); + var listener = new PlainActionFuture(); + registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -120,21 +125,69 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind ); } - public void testGetUnparsedModelMap_ReturnsModelConfigMap_WhenBothInferenceAndSecretsHitsAreFound() { + public void testGetModelWithSecrets() { var client = mockClient(); + String config = """ + { + "model_id": "1", + "task_type": "sparse_embedding", + "service": "foo" + } + """; + String secrets = """ + { + "api_key": "secret" + } + """; + var inferenceHit = SearchHit.createFromMap(Map.of("_index", ".inference")); + inferenceHit.sourceRef(BytesReference.fromByteBuffer(ByteBuffer.wrap(Strings.toUTF8Bytes(config)))); var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference")); + inferenceSecretsHit.sourceRef(BytesReference.fromByteBuffer(ByteBuffer.wrap(Strings.toUTF8Bytes(secrets)))); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit })); var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); - registry.getUnparsedModelMap("1", listener); + var listener = new PlainActionFuture(); + registry.getModelWithSecrets("1", listener); + + var modelConfig = listener.actionGet(TIMEOUT); + assertEquals("1", modelConfig.modelId()); + assertEquals("foo", modelConfig.service()); + assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); + 
assertThat(modelConfig.settings().keySet(), empty()); + assertThat(modelConfig.secrets().keySet(), hasSize(1)); + assertEquals("secret", modelConfig.secrets().get("api_key")); + } + + public void testGetModelNoSecrets() { + var client = mockClient(); + String config = """ + { + "model_id": "1", + "task_type": "sparse_embedding", + "service": "foo" + } + """; + + var inferenceHit = SearchHit.createFromMap(Map.of("_index", ".inference")); + inferenceHit.sourceRef(BytesReference.fromByteBuffer(ByteBuffer.wrap(Strings.toUTF8Bytes(config)))); + + mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); + + var registry = new ModelRegistry(client); + + var listener = new PlainActionFuture(); + registry.getModel("1", listener); + registry.getModel("1", listener); var modelConfig = listener.actionGet(TIMEOUT); - assertThat(modelConfig.config(), nullValue()); - assertThat(modelConfig.secrets(), nullValue()); + assertEquals("1", modelConfig.modelId()); + assertEquals("foo", modelConfig.service()); + assertEquals(TaskType.SPARSE_EMBEDDING, modelConfig.taskType()); + assertThat(modelConfig.settings().keySet(), empty()); + assertThat(modelConfig.secrets().keySet(), empty()); } public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java index 6553f1e7f8ae3..605411343533f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import 
org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 9ab33ef777445..0a8bfd20caaf1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index fabb6c3de0fbc..71d14e09872fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java index 56a592a490712..f8480709a3e40 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java @@ -9,10 +9,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.test.ESTestCase; import java.util.Collections; @@ -127,7 +127,12 @@ public void testParseConfigStrictWithUnknownSettings() { containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") ); } else { - var parsed = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Collections.emptyMap()); + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); } } @@ -158,7 +163,12 @@ public void testParseConfigStrictWithUnknownSettings() { containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") ); } else { - var parsed = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Collections.emptyMap()); + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); } } @@ -190,7 +200,12 @@ public void testParseConfigStrictWithUnknownSettings() { containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") ); } else { - var parsed = service.parsePersistedConfig("foo", 
TaskType.SPARSE_EMBEDDING, settings, Collections.emptyMap()); + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); } } } @@ -223,7 +238,7 @@ public void testParseRequestConfig_DefaultModel() { } private ElserMlNodeService createService(Client client) { - var context = new InferenceServicePlugin.InferenceServiceFactoryContext(client); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElserMlNodeService(context); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 9cd7a4b4eed2c..0d57e90dcd31b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -273,7 +273,12 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOE getSecretSettingsMap("secret") ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()); + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); @@ -301,7 +306,12 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()) + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat( 
@@ -324,7 +334,12 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUr getSecretSettingsMap("secret") ); - var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()); + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); assertThat(model, instanceOf(OpenAiEmbeddingsModel.class)); @@ -353,7 +368,12 @@ public void testParsePersistedConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()) + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat( @@ -381,7 +401,12 @@ public void testParsePersistedConfig_ThrowsWhenAnExtraKeyExistsInSecretsSettings var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()) + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat( @@ -407,7 +432,12 @@ public void testParsePersistedConfig_ThrowsWhenAnExtraKeyExistsInSecrets() throw var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()) + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat( @@ -435,7 +465,12 @@ public void testParsePersistedConfig_ThrowsWhenAnExtraKeyExistsInServiceSettings var thrownException = expectThrows( 
ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()) + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat( @@ -463,7 +498,12 @@ public void testParsePersistedConfig_ThrowsWhenAnExtraKeyExistsInTaskSettings() var thrownException = expectThrows( ElasticsearchStatusException.class, - () -> service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config(), persistedConfig.secrets()) + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) ); assertThat( diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index 2373dc4d54c99..22cdb752d1e8d 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -92,6 +92,7 @@ dependencies { testImplementation project(path: xpackModule('wildcard')) // ml deps api project(':libs:elasticsearch-grok') + api project(':modules:lang-mustache') api "org.apache.commons:commons-math3:3.6.1" api "com.ibm.icu:icu4j:${versions.icu4j}" api "org.apache.lucene:lucene-analysis-icu:${versions.lucene}" diff --git a/x-pack/plugin/ml/qa/basic-multi-node/build.gradle b/x-pack/plugin/ml/qa/basic-multi-node/build.gradle index 3268c15879b92..fca019a6fc689 100644 --- a/x-pack/plugin/ml/qa/basic-multi-node/build.gradle +++ b/x-pack/plugin/ml/qa/basic-multi-node/build.gradle @@ -3,6 +3,10 @@ import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.legacy-java-rest-test' +dependencies { + javaRestTestImplementation(project(':modules:lang-mustache')) +} + testClusters.configureEach { testDistribution = 'DEFAULT' numberOfNodes = 3 diff --git a/x-pack/plugin/ml/qa/disabled/build.gradle b/x-pack/plugin/ml/qa/disabled/build.gradle index 97a7b0eed73ad..232700d5f84aa 100644 --- 
a/x-pack/plugin/ml/qa/disabled/build.gradle +++ b/x-pack/plugin/ml/qa/disabled/build.gradle @@ -2,10 +2,9 @@ import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.legacy-java-rest-test' -//dependencies { -// testImplementation project(":x-pack:plugin:core") -// testImplementation project(path: xpackModule('ml')) -//} +dependencies { + javaRestTestImplementation(project(':modules:lang-mustache')) +} testClusters.configureEach { testDistribution = 'DEFAULT' diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle b/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle new file mode 100644 index 0000000000000..83226acb383c7 --- /dev/null +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle @@ -0,0 +1,12 @@ +apply plugin: 'elasticsearch.internal-java-rest-test' + +dependencies { + javaRestTestImplementation(testArtifact(project(xpackModule('core')))) + javaRestTestImplementation(testArtifact(project(xpackModule('ml')))) + javaRestTestImplementation project(path: xpackModule('inference')) + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') +} + +tasks.named("javaRestTest").configure { + usesDefaultDistribution() +} diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java new file mode 100644 index 0000000000000..c4c3ee016be0e --- /dev/null +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java @@ -0,0 +1,309 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.client.Request; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; +import org.junit.ClassRule; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; + +public class CoordinatedInferenceIngestIT extends ESRestTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .distribution(DistributionType.DEFAULT) + .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.security.enabled", "true") + .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin") + .user("x_pack_rest_user", "x-pack-test-password") + .build(); + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @Override + protected Settings restClientSettings() { + String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); + } + + @SuppressWarnings("unchecked") + public void testIngestWithMultipleModelTypes() throws IOException { + // Create an inference service 
model, dfa model and pytorch model + var inferenceServiceModelId = "is_model"; + var boostedTreeModelId = "boosted_tree_model"; + var pyTorchModelId = "pytorch_model"; + + putInferenceServiceModel(inferenceServiceModelId, TaskType.SPARSE_EMBEDDING); + putBoostedTreeRegressionModel(boostedTreeModelId); + putPyTorchModel(pyTorchModelId); + putPyTorchModelDefinition(pyTorchModelId); + putPyTorchModelVocabulary(List.of("these", "are", "my", "words"), pyTorchModelId); + startDeployment(pyTorchModelId); + + String docs = """ + [ + { + "_source": { + "title": "my", + "body": "these are" + } + }, + { + "_source": { + "title": "are", + "body": "my words" + } + } + ] + """; + + { + var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinition(inferenceServiceModelId), docs); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0))); + var sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0)); + assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1))); + sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1)); + assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + } + + { + var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinition(pyTorchModelId), docs); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0))); + List> results = (List>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0)); + assertThat(results.get(0), contains(1.0, 1.0)); + assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1))); + results = (List>) MapHelper.dig("doc._source.ml.body", 
simulatedDocs.get(1)); + assertThat(results.get(0), contains(1.0, 1.0)); + } + + String boostedTreeDocs = Strings.format(""" + [ + { + "_source": %s + }, + { + "_source": %s + } + ] + """, ExampleModels.randomBoostedTreeModelDoc(), ExampleModels.randomBoostedTreeModelDoc()); + { + var responseMap = simulatePipeline( + ExampleModels.boostedTreeRegressionModelPipelineDefinition(boostedTreeModelId), + boostedTreeDocs + ); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + assertEquals(boostedTreeModelId, MapHelper.dig("doc._source.ml.regression.model_id", simulatedDocs.get(0))); + assertNotNull(MapHelper.dig("doc._source.ml.regression.predicted_value", simulatedDocs.get(0))); + assertEquals(boostedTreeModelId, MapHelper.dig("doc._source.ml.regression.model_id", simulatedDocs.get(1))); + assertNotNull(MapHelper.dig("doc._source.ml.regression.predicted_value", simulatedDocs.get(1))); + } + } + + @SuppressWarnings("unchecked") + public void testPipelineConfiguredWithFieldMap() throws IOException { + // Create an inference service model, dfa model and pytorch model + var inferenceServiceModelId = "is_model"; + var boostedTreeModelId = "boosted_tree_model"; + var pyTorchModelId = "pytorch_model"; + + putInferenceServiceModel(inferenceServiceModelId, TaskType.SPARSE_EMBEDDING); + putBoostedTreeRegressionModel(boostedTreeModelId); + putPyTorchModel(pyTorchModelId); + putPyTorchModelDefinition(pyTorchModelId); + putPyTorchModelVocabulary(List.of("these", "are", "my", "words"), pyTorchModelId); + startDeployment(pyTorchModelId); + + String docs = """ + [ + { + "_source": { + "body": "these are" + } + }, + { + "_source": { + "body": "my words" + } + } + ] + """; + + { + var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinitionWithFieldMap(pyTorchModelId), docs); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + assertEquals(pyTorchModelId, 
MapHelper.dig("doc._source.ml.inference.model_id", simulatedDocs.get(0))); + List> results = (List>) MapHelper.dig( + "doc._source.ml.inference.predicted_value", + simulatedDocs.get(0) + ); + assertThat(results.get(0), contains(1.0, 1.0)); + assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.inference.model_id", simulatedDocs.get(1))); + results = (List>) MapHelper.dig("doc._source.ml.inference.predicted_value", simulatedDocs.get(1)); + assertThat(results.get(0), contains(1.0, 1.0)); + } + + { + // Inference service models cannot be configured with the field map + var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinitionWithFieldMap(inferenceServiceModelId), docs); + var simulatedDocs = (List>) responseMap.get("docs"); + var errorMsg = (String) MapHelper.dig("error.reason", simulatedDocs.get(0)); + assertThat(errorMsg, containsString("[is_model] is configured for the _inference API and does not accept documents as input")); + assertThat(simulatedDocs, hasSize(2)); + } + + } + + @SuppressWarnings("unchecked") + public void testWithUndeployedPyTorchModel() throws IOException { + var pyTorchModelId = "test-undeployed"; + + putPyTorchModel(pyTorchModelId); + putPyTorchModelDefinition(pyTorchModelId); + putPyTorchModelVocabulary(List.of("these", "are", "my", "words"), pyTorchModelId); + + String docs = """ + [ + { + "_source": { + "title": "my", + "body": "these are" + } + }, + { + "_source": { + "title": "are", + "body": "my words" + } + } + ] + """; + + { + var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinition(pyTorchModelId), docs); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + var errorMsg = (String) MapHelper.dig("error.reason", simulatedDocs.get(0)); + assertEquals("[" + pyTorchModelId + "] is not an inference service model or a deployed ml model", errorMsg); + } + + { + var responseMap = 
simulatePipeline(ExampleModels.nlpModelPipelineDefinitionWithFieldMap(pyTorchModelId), docs); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + var errorMsg = (String) MapHelper.dig("error.reason", simulatedDocs.get(0)); + assertEquals( + "Model [" + pyTorchModelId + "] must be deployed to use. Please deploy with the start trained model deployment API.", + errorMsg + ); + } + } + + private Map putInferenceServiceModel(String modelId, TaskType taskType) throws IOException { + String endpoint = org.elasticsearch.common.Strings.format("_inference/%s/%s", taskType, modelId); + var request = new Request("PUT", endpoint); + var modelConfig = ExampleModels.mockServiceModelConfig(); + request.setJsonEntity(modelConfig); + var response = client().performRequest(request); + return entityAsMap(response); + } + + private void putPyTorchModel(String modelId) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + modelId); + var modelConfiguration = ExampleModels.pytorchPassThroughModelConfig(); + request.setJsonEntity(modelConfiguration); + client().performRequest(request); + } + + protected void putPyTorchModelVocabulary(List vocabulary, String modelId) throws IOException { + List vocabularyWithPad = new ArrayList<>(); + vocabularyWithPad.add("[PAD]"); + vocabularyWithPad.add("[UNK]"); + vocabularyWithPad.addAll(vocabulary); + String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(",")); + + Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary"); + request.setJsonEntity(Strings.format(""" + { "vocabulary": [%s] } + """, quotedWords)); + client().performRequest(request); + } + + protected Map simulatePipeline(String pipelineDef, String docs) throws IOException { + String simulate = Strings.format(""" + { + "pipeline": %s, + "docs": %s + }""", pipelineDef, docs); + + Request request = new Request("POST", 
"_ingest/pipeline/_simulate?error_trace=true"); + request.setJsonEntity(simulate); + return entityAsMap(client().performRequest(request)); + } + + protected void putPyTorchModelDefinition(String modelId) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0"); + String body = Strings.format( + """ + {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", + ExampleModels.RAW_PYTORCH_MODEL_SIZE, + ExampleModels.BASE_64_ENCODED_PYTORCH_MODEL + ); + request.setJsonEntity(body); + client().performRequest(request); + } + + protected void startDeployment(String modelId) throws IOException { + String endPoint = "/_ml/trained_models/" + + modelId + + "/deployment/_start?timeout=40s&wait_for=started&threads_per_allocation=1&number_of_allocations=1"; + + Request request = new Request("POST", endPoint); + client().performRequest(request); + } + + private void putBoostedTreeRegressionModel(String modelId) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + modelId); + var modelConfiguration = ExampleModels.boostedTreeRegressionModel(); + request.setJsonEntity(modelConfiguration); + client().performRequest(request); + } + + public Map getModel(String modelId, TaskType taskType) throws IOException { + var endpoint = org.elasticsearch.common.Strings.format("_inference/%s/%s", taskType, modelId); + var request = new Request("GET", endpoint); + var reponse = client().performRequest(request); + return entityAsMap(reponse); + } +} diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExampleModels.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExampleModels.java new file mode 100644 index 0000000000000..f9f4d6bf474e9 --- /dev/null +++ 
b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExampleModels.java @@ -0,0 +1,305 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.Base64; +import java.util.Map; + +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.elasticsearch.test.ESTestCase.randomIntBetween; + +public class ExampleModels { + + static final String BASE_64_ENCODED_PYTORCH_MODEL = + "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" + + "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" + + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" + + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh" + + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele" + + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k" + + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ" + + "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + 
"WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq" + + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7" + + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3" + + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28" + + "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw" + + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW" + + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0" + + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts" + + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs" + + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn" + + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU" + + "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe" + + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE" + + "AAJIEAAAAAA=="; + static final long RAW_PYTORCH_MODEL_SIZE; // size of the model before base64 encoding + static { + RAW_PYTORCH_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_PYTORCH_MODEL).length; + } + + static String pytorchPassThroughModelConfig() { + return """ + { + "description": "simple model for testing", + "model_type": "pytorch", + "inference_config": { + "pass_through": { + "tokenization": { + "bert": { + "with_special_tokens": false + } + } + } + } + } + """; + } + + static String mockServiceModelConfig() { + return 
org.elasticsearch.common.Strings.format(""" + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + "temperature": 3 + } + } + """); + } + + private static final String REGRESSION_DEFINITION = """ + { "preprocessors": [ + { + "one_hot_encoding": { + "field": "col1", + "hot_map": { + "male": "col1_male", + "female": "col1_female" + } + } + }, + { + "target_mean_encoding": { + "field": "col2", + "feature_name": "col2_encoded", + "target_map": { + "S": 5.0, + "M": 10.0, + "L": 20 + }, + "default_value": 5.0 + } + }, + { + "frequency_encoding": { + "field": "col3", + "feature_name": "col3_encoded", + "frequency_map": { + "none": 0.75, + "true": 0.10, + "false": 0.15 + } + } + } + ], + "trained_model": { + "ensemble": { + "feature_names": [ + "col1_male", + "col1_female", + "col2_encoded", + "col3_encoded", + "col4" + ], + "aggregate_output": { + "weighted_sum": { + "weights": [ + 0.5, + 0.5 + ] + } + }, + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": [ + "col1_male", + "col1_female", + "col4" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12.0, + "threshold": 10.0, + "decision_type": "lte", + "number_samples": 300, + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "number_samples": 100, + "leaf_value": 1 + }, + { + "node_index": 2, + "number_samples": 200, + "leaf_value": 2 + } + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": [ + "col2_encoded", + "col3_encoded", + "col4" + ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12.0, + "threshold": 10.0, + "decision_type": "lte", + "default_left": true, + "number_samples": 150, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "number_samples": 50, + "leaf_value": 1 + }, + { + "node_index": 2, + "number_samples": 100, + "leaf_value": 2 + } + ], + 
"target_type": "regression" + } + } + ] + } + } + }"""; + + public static String boostedTreeRegressionModel() { + return Strings.format(""" + { + "input": { + "field_names": [ + "col1", + "col2", + "col3", + "col4" + ] + }, + "description": "test model for regression", + "inference_config": { + "regression": {} + }, + "definition": %s + }""", REGRESSION_DEFINITION); + } + + public static String nlpModelPipelineDefinition(String modelId) { + return Strings.format(""" + { + "processors": [ + { + "inference": { + "model_id": "%s", + "input_output": { + "input_field": "body", + "output_field": "ml.body" + } + } + } + ] + }""", modelId); + } + + public static String nlpModelPipelineDefinitionWithFieldMap(String modelId) { + return Strings.format(""" + { + "processors": [ + { + "inference": { + "model_id": "%s", + "field_map": { + "body": "input" + } + } + } + ] + }""", modelId); + } + + public static String boostedTreeRegressionModelPipelineDefinition(String modelId) { + return Strings.format(""" + { + "processors": [ + { + "inference": { + "target_field": "ml.regression", + "model_id": "%s", + "inference_config": { + "regression": {} + }, + "field_map": { + "col1": "col1", + "col2": "col2", + "col3": "col3", + "col4": "col4" + } + } + } + ] + }""", modelId); + } + + public static String randomBoostedTreeModelDoc() throws IOException { + Map values = Map.of( + "col1", + randomFrom("female", "male"), + "col2", + randomFrom("S", "M", "L", "XL"), + "col3", + randomFrom("true", "false", "none", "other"), + "col4", + randomIntBetween(0, 10) + ); + + try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(values)) { + return XContentHelper.convertToJson(BytesReference.bytes(xContentBuilder), false, XContentType.JSON); + } + } + + private ExampleModels() {} +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle b/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle index ae537f865e65f..db53b9aec7f1f 100644 --- 
a/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle @@ -1,4 +1,3 @@ -import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.legacy-java-rest-test' dependencies { @@ -6,6 +5,7 @@ dependencies { javaRestTestImplementation(testArtifact(project(xpackModule('ml')))) javaRestTestImplementation(testArtifact(project(xpackModule('security')))) javaRestTestImplementation project(path: ':modules:ingest-common') + javaRestTestImplementation(project(':modules:lang-mustache')) javaRestTestImplementation project(path: ':modules:reindex') javaRestTestImplementation project(path: ':modules:transport-netty4') javaRestTestImplementation project(path: xpackModule('autoscaling')) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIT.java index c8c580d2933c1..24bdbe23eb5ab 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIT.java @@ -283,7 +283,7 @@ private void putModelAlias(String modelAlias, String newModel) throws IOExceptio } }"""; - private static final String REGRESSION_CONFIG = Strings.format(""" + public static final String REGRESSION_CONFIG = Strings.format(""" { "input": { "field_names": [ @@ -325,5 +325,4 @@ private void putModel(String modelId, String modelConfiguration) throws IOExcept request.setJsonEntity(modelConfiguration); client().performRequest(request); } - } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextEmbeddingQueryIT.java 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextEmbeddingQueryIT.java index 8e425ea071879..82597e16837c6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextEmbeddingQueryIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextEmbeddingQueryIT.java @@ -288,7 +288,7 @@ public void testSearchWithMissingModel() { String indexName = modelId + "-index"; var e = expectThrows(ResponseException.class, () -> textEmbeddingSearch(indexName, "the machine is leaking", modelId, "embedding")); - assertThat(e.getMessage(), containsString("Could not find trained model [missing-model]")); + assertThat(e.getMessage(), containsString("[missing-model] is not an inference service model or a deployed ml model")); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextExpansionQueryIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextExpansionQueryIT.java index dbf489e8abf23..6075391326509 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextExpansionQueryIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TextExpansionQueryIT.java @@ -262,7 +262,7 @@ public void testSearchWithMissingModel() throws IOException { String modelId = "missing-model"; String indexName = modelId + "-index"; var e = expectThrows(ResponseException.class, () -> textExpansionSearch(indexName, "the machine is leaking", modelId, "ml.tokens")); - assertThat(e.getMessage(), containsString("Could not find trained model [missing-model]")); + assertThat(e.getMessage(), containsString("[missing-model] is not an inference service 
model or a deployed ml model")); } protected Response textExpansionSearch(String index, String modelText, String modelId, String fieldName) throws IOException { diff --git a/x-pack/plugin/ml/qa/single-node-tests/build.gradle b/x-pack/plugin/ml/qa/single-node-tests/build.gradle index eb86ca600d75f..6979ec4dcbd31 100644 --- a/x-pack/plugin/ml/qa/single-node-tests/build.gradle +++ b/x-pack/plugin/ml/qa/single-node-tests/build.gradle @@ -2,6 +2,10 @@ import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.legacy-java-rest-test' +dependencies { + javaRestTestImplementation(project(':modules:lang-mustache')) +} + testClusters.configureEach { testDistribution = 'DEFAULT' setting 'xpack.security.enabled', 'false' diff --git a/x-pack/plugin/ml/src/main/java/module-info.java b/x-pack/plugin/ml/src/main/java/module-info.java index a73c9bdfa32b4..0f3fdd836feca 100644 --- a/x-pack/plugin/ml/src/main/java/module-info.java +++ b/x-pack/plugin/ml/src/main/java/module-info.java @@ -17,6 +17,7 @@ requires org.elasticsearch.grok; requires org.elasticsearch.server; requires org.elasticsearch.xcontent; + requires org.elasticsearch.mustache; requires org.apache.httpcomponents.httpcore; requires org.apache.httpcomponents.httpclient; requires org.apache.httpcomponents.httpasyncclient; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 3320a51009257..db23e7796f862 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.CancelJobModelSnapshotUpgradeAction; import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction; import org.elasticsearch.xpack.core.ml.action.CloseJobAction; +import 
org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; @@ -197,6 +198,7 @@ import org.elasticsearch.xpack.ml.action.TransportCancelJobModelSnapshotUpgradeAction; import org.elasticsearch.xpack.ml.action.TransportClearDeploymentCacheAction; import org.elasticsearch.xpack.ml.action.TransportCloseJobAction; +import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarEventAction; @@ -1573,6 +1575,7 @@ public List getRestHandlers( TransportUpdateTrainedModelAssignmentStateAction.class ) ); + actionHandlers.add(new ActionHandler<>(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)); } } return actionHandlers; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java new file mode 100644 index 0000000000000..d90c9ec807495 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -0,0 +1,188 @@ + +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentUtils; + +import java.util.ArrayList; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TransportCoordinatedInferenceAction extends HandledTransportAction< + CoordinatedInferenceAction.Request, + InferModelAction.Response> { + + private final Client client; + private final ClusterService clusterService; + + @Inject + public 
TransportCoordinatedInferenceAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService + ) { + super( + CoordinatedInferenceAction.NAME, + transportService, + actionFilters, + CoordinatedInferenceAction.Request::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + this.client = client; + this.clusterService = clusterService; + } + + @Override + protected void doExecute(Task task, CoordinatedInferenceAction.Request request, ActionListener listener) { + if (request.getRequestModelType() == CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL) { + // must be an inference service model or ml hosted model + forNlp(request, listener); + } else if (request.hasObjects()) { + // Inference service models do not accept a document map + // If this fails check if the model is an inference service + // model and error accordingly + doInClusterModel(request, wrapCheckForServiceModelOnMissing(request.getModelId(), listener)); + } else { + forNlp(request, listener); + } + } + + private void forNlp(CoordinatedInferenceAction.Request request, ActionListener listener) { + var clusterState = clusterService.state(); + var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getModelId(), clusterState); + if (assignments == null || assignments.isEmpty()) { + doInferenceServiceModel( + request, + ActionListener.wrap( + listener::onResponse, + e -> replaceErrorOnMissing( + e, + () -> new ElasticsearchStatusException( + "[" + request.getModelId() + "] is not an inference service model or a deployed ml model", + RestStatus.NOT_FOUND + ), + listener + ) + ) + ); + } else { + doInClusterModel(request, listener); + } + } + + private void doInferenceServiceModel(CoordinatedInferenceAction.Request request, ActionListener listener) { + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + InferenceAction.INSTANCE, + new InferenceAction.Request(TaskType.ANY, request.getModelId(), request.getInputs(), 
request.getTaskSettings()), + ActionListener.wrap(r -> listener.onResponse(translateInferenceServiceResponse(r.getResults())), listener::onFailure) + ); + } + + private void doInClusterModel(CoordinatedInferenceAction.Request request, ActionListener listener) { + var inferModelRequest = translateRequest(request); + executeAsyncWithOrigin(client, ML_ORIGIN, InferModelAction.INSTANCE, inferModelRequest, listener); + } + + static InferModelAction.Request translateRequest(CoordinatedInferenceAction.Request request) { + InferenceConfigUpdate inferenceConfigUpdate = request.getInferenceConfigUpdate() == null + ? EmptyConfigUpdate.INSTANCE + : request.getInferenceConfigUpdate(); + + var inferModelRequest = request.hasObjects() + ? InferModelAction.Request.forIngestDocs( + request.getModelId(), + request.getObjectsToInfer(), + inferenceConfigUpdate, + request.getPreviouslyLicensed(), + request.getInferenceTimeout() + ) + : InferModelAction.Request.forTextInput( + request.getModelId(), + inferenceConfigUpdate, + request.getInputs(), + request.getPreviouslyLicensed(), + request.getInferenceTimeout() + ); + inferModelRequest.setPrefixType(request.getPrefixType()); + inferModelRequest.setHighPriority(request.getHighPriority()); + return inferModelRequest; + } + + private ActionListener wrapCheckForServiceModelOnMissing( + String modelId, + ActionListener listener + ) { + return ActionListener.wrap(listener::onResponse, originalError -> { + if (ExceptionsHelper.unwrapCause(originalError) instanceof ResourceNotFoundException) { + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + GetInferenceModelAction.INSTANCE, + new GetInferenceModelAction.Request(modelId, TaskType.ANY), + ActionListener.wrap( + model -> listener.onFailure( + new ElasticsearchStatusException( + "[" + modelId + "] is configured for the _inference API and does not accept documents as input", + RestStatus.BAD_REQUEST + ) + ), + e -> listener.onFailure(originalError) + ) + ); + } else { + 
listener.onFailure(originalError); + } + }); + } + + private void replaceErrorOnMissing( + Exception originalError, + Supplier replaceOnMissing, + ActionListener listener + ) { + if (ExceptionsHelper.unwrapCause(originalError) instanceof ResourceNotFoundException) { + listener.onFailure(replaceOnMissing.get()); + } else { + listener.onFailure(originalError); + } + } + + static InferModelAction.Response translateInferenceServiceResponse(InferenceServiceResults inferenceResults) { + var legacyResults = new ArrayList(inferenceResults.transformToLegacyFormat()); + return new InferModelAction.Response(legacyResults, null, false); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentUtils.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentUtils.java index 257c944c08605..3640d8dcb2808 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentUtils.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentUtils.java @@ -7,10 +7,16 @@ package org.elasticsearch.xpack.ml.inference.assignment; +import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; + +import java.util.List; +import java.util.Optional; public class TrainedModelAssignmentUtils { public static final String NODES_CHANGED_REASON = "nodes changed"; @@ -24,5 +30,22 @@ public static RoutingInfo createShuttingDownRoute(RoutingInfo existingRoute) 
{ return routeUpdate.apply(existingRoute); } + public static List modelAssignments(String modelId, ClusterState state) { + String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(state).getModelId(modelId)).orElse(modelId); + + List assignments; + + TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(state); + TrainedModelAssignment assignment = trainedModelAssignmentMetadata.getDeploymentAssignment(concreteModelId); + if (assignment != null) { + assignments = List.of(assignment); + } else { + // look up by model + assignments = trainedModelAssignmentMetadata.getDeploymentsUsingModel(concreteModelId); + } + + return assignments; + } + private TrainedModelAssignmentUtils() {} } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index e600ddd42107f..470605dcb2d9c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -25,11 +25,11 @@ import org.elasticsearch.ingest.Processor; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.MlConfigVersion; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -170,7 +170,7 @@ private InferenceProcessor( this.client = ExceptionsHelper.requireNonNull(client, "client"); this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor"); this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); - this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + this.inferenceConfig = inferenceConfig; this.ignoreMissing = ignoreMissing; if (configuredWithInputsFields) { @@ -191,7 +191,7 @@ public String getModelId() { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { - InferModelAction.Request request; + CoordinatedInferenceAction.Request request; try { request = buildRequest(ingestDocument); } catch (ElasticsearchStatusException e) { @@ -202,7 +202,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer handleResponse(r, ingestDocument, handler), e -> handler.accept(ingestDocument, e)) ); @@ -223,7 +223,7 @@ void handleResponse(InferModelAction.Response response, IngestDocument ingestDoc } } - InferModelAction.Request buildRequest(IngestDocument ingestDocument) { + CoordinatedInferenceAction.Request buildRequest(IngestDocument ingestDocument) { if (configuredWithInputsFields) { // ignore missing only applies when using an input field list List requestInputs = new ArrayList<>(); @@ -246,10 +246,10 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { } } } - var request = InferModelAction.Request.forTextInput( + var request = CoordinatedInferenceAction.Request.forTextInput( modelId, - inferenceConfig, requestInputs, + inferenceConfig, previouslyLicensed, InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); @@ -263,12 +263,13 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { } LocalModel.mapFieldsIfNecessary(fields, fieldMap); - var request = 
InferModelAction.Request.forIngestDocs( + var request = CoordinatedInferenceAction.Request.forMapInput( modelId, List.of(fields), inferenceConfig, previouslyLicensed, - InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, + CoordinatedInferenceAction.Request.RequestModelType.UNKNOWN ); request.setPrefixType(TrainedModelPrefixStrings.PrefixType.INGEST); return request; @@ -409,15 +410,9 @@ public InferenceProcessor create( String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); - InferenceConfigUpdate inferenceConfigUpdate; + InferenceConfigUpdate inferenceConfigUpdate = null; Map inferenceConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, INFERENCE_CONFIG); - if (inferenceConfigMap == null) { - if (minNodeVersion.before(EmptyConfigUpdate.minimumSupportedVersion())) { - // an inference config is required when the empty update is not supported - throw newConfigurationException(TYPE, tag, INFERENCE_CONFIG, "required property is missing"); - } - inferenceConfigUpdate = new EmptyConfigUpdate(); - } else { + if (inferenceConfigMap != null) { inferenceConfigUpdate = inferenceConfigUpdateFromMap(inferenceConfigMap); } @@ -445,7 +440,7 @@ public InferenceProcessor create( ); } - if (inferenceConfigUpdate.getResultsField() != null) { + if (inferenceConfigUpdate != null && inferenceConfigUpdate.getResultsField() != null) { throw newConfigurationException( TYPE, tag, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java index 1443ccd687620..42f7d8cf0a3b3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankService.java @@ -9,10 +9,14 @@ import org.elasticsearch.action.ActionListener; 
import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.script.GeneralScriptException; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptType; import org.elasticsearch.script.TemplateScript; +import org.elasticsearch.script.mustache.MustacheInvalidParameterException; +import org.elasticsearch.script.mustache.MustacheScriptEngine; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; @@ -37,11 +41,15 @@ import java.util.Map; import java.util.Optional; +import static java.util.Map.entry; import static org.elasticsearch.script.Script.DEFAULT_TEMPLATE_LANG; import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_CONFIG_QUERY_BAD_FORMAT; public class LearnToRankService { + private static final Map SCRIPT_OPTIONS = Map.ofEntries( + entry(MustacheScriptEngine.DETECT_MISSING_PARAMS_OPTION, Boolean.TRUE.toString()) + ); private final ModelLoadingService modelLoadingService; private final TrainedModelProvider trainedModelProvider; private final ScriptService scriptService; @@ -126,11 +134,6 @@ private LearnToRankConfig applyParams(LearnToRankConfig config, Map featureExtractorBuilders = new ArrayList<>(); for (LearnToRankFeatureExtractorBuilder featureExtractorBuilder : config.getFeatureExtractorBuilders()) { @@ -176,15 +179,25 @@ private QueryExtractorBuilder applyParams(QueryExtractorBuilder queryExtractorBu return queryExtractorBuilder; } - Script script = new Script(ScriptType.INLINE, DEFAULT_TEMPLATE_LANG, templateSource, Collections.emptyMap()); - String parsedTemplate = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params).execute(); - // TODO: handle missing params. 
- XContentParser parser = XContentType.JSON.xContent().createParser(parserConfiguration, parsedTemplate); - - return new QueryExtractorBuilder( - queryExtractorBuilder.featureName(), - QueryProvider.fromXContent(parser, false, INFERENCE_CONFIG_QUERY_BAD_FORMAT) - ); + try { + Script script = new Script(ScriptType.INLINE, DEFAULT_TEMPLATE_LANG, templateSource, SCRIPT_OPTIONS, Collections.emptyMap()); + String parsedTemplate = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params).execute(); + XContentParser parser = XContentType.JSON.xContent().createParser(parserConfiguration, parsedTemplate); + + return new QueryExtractorBuilder( + queryExtractorBuilder.featureName(), + QueryProvider.fromXContent(parser, false, INFERENCE_CONFIG_QUERY_BAD_FORMAT) + ); + } catch (GeneralScriptException e) { + if (e.getRootCause().getClass().getName().equals(MustacheInvalidParameterException.class.getName())) { + // Can't use instanceof since it return unexpected result. + return new QueryExtractorBuilder( + queryExtractorBuilder.featureName(), + QueryProvider.fromParsedQuery(new MatchNoneQueryBuilder()) + ); + } + throw e; + } } private String templateSource(QueryProvider queryProvider) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 7cdeeb3d559ec..12019e93ba713 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -24,6 +24,7 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import 
org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; @@ -124,10 +125,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return weightedTokensToQuery(fieldName, weightedTokensSupplier.get(), queryRewriteContext); } - InferModelAction.Request inferRequest = InferModelAction.Request.forTextInput( + CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( modelId, - TextExpansionConfigUpdate.EMPTY_UPDATE, List.of(modelText), + TextExpansionConfigUpdate.EMPTY_UPDATE, false, InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API ); @@ -136,32 +137,38 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws SetOnce textExpansionResultsSupplier = new SetOnce<>(); queryRewriteContext.registerAsyncAction((client, listener) -> { - executeAsyncWithOrigin(client, ML_ORIGIN, InferModelAction.INSTANCE, inferRequest, ActionListener.wrap(inferenceResponse -> { + executeAsyncWithOrigin( + client, + ML_ORIGIN, + CoordinatedInferenceAction.INSTANCE, + inferRequest, + ActionListener.wrap(inferenceResponse -> { - if (inferenceResponse.getInferenceResults().isEmpty()) { - listener.onFailure(new IllegalStateException("inference response contain no results")); - return; - } + if (inferenceResponse.getInferenceResults().isEmpty()) { + listener.onFailure(new IllegalStateException("inference response contain no results")); + return; + } - if (inferenceResponse.getInferenceResults().get(0) instanceof TextExpansionResults textExpansionResults) { - textExpansionResultsSupplier.set(textExpansionResults); - listener.onResponse(null); - } else if (inferenceResponse.getInferenceResults().get(0) instanceof WarningInferenceResults warning) { - listener.onFailure(new IllegalStateException(warning.getWarning())); - } else { - listener.onFailure( - new IllegalStateException( - "expected a result of type [" - + 
TextExpansionResults.NAME - + "] received [" - + inferenceResponse.getInferenceResults().get(0).getWriteableName() - + "]. Is [" - + modelId - + "] a compatible model?" - ) - ); - } - }, listener::onFailure)); + if (inferenceResponse.getInferenceResults().get(0) instanceof TextExpansionResults textExpansionResults) { + textExpansionResultsSupplier.set(textExpansionResults); + listener.onResponse(null); + } else if (inferenceResponse.getInferenceResults().get(0) instanceof WarningInferenceResults warning) { + listener.onFailure(new IllegalStateException(warning.getWarning())); + } else { + listener.onFailure( + new IllegalStateException( + "expected a result of type [" + + TextExpansionResults.NAME + + "] received [" + + inferenceResponse.getInferenceResults().get(0).getWriteableName() + + "]. Is [" + + modelId + + "] a compatible model?" + ) + ); + } + }, listener::onFailure) + ); }); return new TextExpansionQueryBuilder(this, textExpansionResultsSupplier); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java index 72663b3f8a7bd..bd0916065ec5f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; @@ -93,17 +94,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) 
throws @Override public void buildVector(Client client, ActionListener listener) { - InferModelAction.Request inferRequest = InferModelAction.Request.forTextInput( + CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( modelId, - TextEmbeddingConfigUpdate.EMPTY_INSTANCE, List.of(modelText), + TextEmbeddingConfigUpdate.EMPTY_INSTANCE, false, InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API ); inferRequest.setHighPriority(true); inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH); - executeAsyncWithOrigin(client, ML_ORIGIN, InferModelAction.INSTANCE, inferRequest, ActionListener.wrap(response -> { + executeAsyncWithOrigin(client, ML_ORIGIN, CoordinatedInferenceAction.INSTANCE, inferRequest, ActionListener.wrap(response -> { if (response.getInferenceResults().isEmpty()) { listener.onFailure(new IllegalStateException("text embedding inference response contain no results")); return; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/KDETests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/KDETests.java index e794b5afb258c..80d5a3ad71136 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/KDETests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/KDETests.java @@ -23,6 +23,7 @@ public void testEmpty() { assertThat(kde.data(), equalTo(new double[0])); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102876") public void testCdfAndSf() { double[] data = DoubleStream.generate(() -> randomDoubleBetween(0.0, 100.0, true)).limit(101).toArray(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 5c98ac53c7228..0698c266400b0 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -313,37 +313,6 @@ public void testCreateProcessorWithTooOldMinNodeVersionNlp() throws IOException }); } - public void testCreateProcessorWithEmptyConfigNotSupportedOnOldNode() throws IOException { - Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); - - includeNodeInfoValues.forEach(includeNodeInfo -> { - InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( - client, - clusterService, - Settings.EMPTY, - includeNodeInfo - ); - try { - processorFactory.accept(builderClusterStateWithModelReferences(MlConfigVersion.V_7_5_0, "model1")); - } catch (IOException ioe) { - throw new AssertionError(ioe.getMessage()); - } - - Map minimalConfig = new HashMap<>() { - { - put(InferenceProcessor.MODEL_ID, "my_model"); - put(InferenceProcessor.TARGET_FIELD, "result"); - } - }; - - ElasticsearchException ex = expectThrows( - ElasticsearchException.class, - () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, minimalConfig) - ); - assertThat(ex.getMessage(), equalTo("[inference_config] required property is missing")); - }); - } - public void testCreateProcessor() { Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 4821efa29631f..6feb014309fe9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -420,7 +420,7 @@ public void testHandleResponseLicenseChanged() { 
IngestDocument document = TestIngestDocument.emptyIngestDocument(); - assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(false)); + assertThat(inferenceProcessor.buildRequest(document).getPreviouslyLicensed(), is(false)); InferModelAction.Response response = new InferModelAction.Response( Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), @@ -432,7 +432,7 @@ public void testHandleResponseLicenseChanged() { assertThat(ex, is(nullValue())); }); - assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true)); + assertThat(inferenceProcessor.buildRequest(document).getPreviouslyLicensed(), is(true)); response = new InferModelAction.Response( Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), @@ -445,7 +445,7 @@ public void testHandleResponseLicenseChanged() { assertThat(ex, is(nullValue())); }); - assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true)); + assertThat(inferenceProcessor.buildRequest(document).getPreviouslyLicensed(), is(true)); inferenceProcessor.handleResponse(response, document, (doc, ex) -> { assertThat(doc, is(not(nullValue()))); @@ -608,8 +608,8 @@ public void testBuildRequestWithInputFields() { document.setFieldValue("unrelated", "text"); var request = inferenceProcessor.buildRequest(document); - assertTrue(request.getObjectsToInfer().isEmpty()); - var requestInputs = request.getTextInput(); + assertNull(request.getObjectsToInfer()); + var requestInputs = request.getInputs(); assertThat(requestInputs, contains("body_text", "title_text")); assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout()); assertEquals(TrainedModelPrefixStrings.PrefixType.INGEST, request.getPrefixType()); @@ -683,7 +683,7 @@ public void testBuildRequestWithInputFields_MissingField() { document.setFieldValue("unrelated", 1.0); var request = 
inferenceProcessor.buildRequest(document); - var requestInputs = request.getTextInput(); + var requestInputs = request.getInputs(); assertThat(requestInputs, contains("body_text", "")); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java index e4d0225637fa1..a2cd0ff8856c6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.ml.inference.ltr; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptService; @@ -26,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; import org.elasticsearch.xpack.core.ml.ltr.MlLTRNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.QueryProviderTests; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -98,7 +101,8 @@ public void testLoadLearnToRankConfig() throws Exception { ); ActionListener listener = mock(ActionListener.class); learnToRankService.loadLearnToRankConfig(GOOD_MODEL, Collections.emptyMap(), listener); - assertBusy(() -> 
verify(listener).onResponse(eq((LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig()))); + + verify(listener).onResponse(eq((LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig())); } @SuppressWarnings("unchecked") @@ -111,7 +115,8 @@ public void testLoadMissingLearnToRankConfig() throws Exception { ); ActionListener listener = mock(ActionListener.class); learnToRankService.loadLearnToRankConfig("non-existing-model", Collections.emptyMap(), listener); - assertBusy(() -> verify(listener).onFailure(isA(ResourceNotFoundException.class))); + + verify(listener).onFailure(isA(ResourceNotFoundException.class)); } @SuppressWarnings("unchecked") @@ -124,7 +129,8 @@ public void testLoadBadLearnToRankConfig() throws Exception { ); ActionListener listener = mock(ActionListener.class); learnToRankService.loadLearnToRankConfig(BAD_MODEL, Collections.emptyMap(), listener); - assertBusy(() -> verify(listener).onFailure(isA(ElasticsearchStatusException.class))); + + verify(listener).onFailure(isA(ElasticsearchStatusException.class)); } @SuppressWarnings("unchecked") @@ -136,27 +142,48 @@ public void testLoadLearnToRankConfigWithTemplate() throws Exception { xContentRegistry() ); - // When no parameters are provided we expect the templated queries not being part of the retrieved config. - ActionListener noParamsListener = mock(ActionListener.class); - learnToRankService.loadLearnToRankConfig(TEMPLATED_GOOD_MODEL, Collections.emptyMap(), noParamsListener); - assertBusy(() -> verify(noParamsListener).onResponse(argThat(retrievedConfig -> { - assertThat(retrievedConfig.getFeatureExtractorBuilders(), hasSize(2)); - assertEquals(retrievedConfig, TEMPLATED_GOOD_MODEL_CONFIG.getInferenceConfig()); - return true; - }))); + // When no parameters are provided we expect query to be rewritten into a match_none query. 
+ { + ActionListener listener = mock(ActionListener.class); + SetOnce retrievedConfig = new SetOnce<>(); + + doAnswer(i -> { + retrievedConfig.set(i.getArgument(0, LearnToRankConfig.class)); + return null; + }).when(listener).onResponse(any()); + learnToRankService.loadLearnToRankConfig(TEMPLATED_GOOD_MODEL, null, listener); + + assertNotNull(retrievedConfig.get()); + assertThat(retrievedConfig.get().getFeatureExtractorBuilders(), hasSize(2)); + + assertEquals( + retrievedConfig.get(), + LearnToRankConfig.builder((LearnToRankConfig) TEMPLATED_GOOD_MODEL_CONFIG.getInferenceConfig()) + .setLearnToRankFeatureExtractorBuilders( + List.of( + new QueryExtractorBuilder("feature_1", QueryProvider.fromParsedQuery(new MatchNoneQueryBuilder())), + new QueryExtractorBuilder("feature_2", QueryProvider.fromParsedQuery(new MatchNoneQueryBuilder())) + ) + ) + .build() + ); + } // Now testing when providing all the params of the template. - ActionListener allParamsListener = mock(ActionListener.class); - learnToRankService.loadLearnToRankConfig( - TEMPLATED_GOOD_MODEL, - Map.ofEntries(Map.entry("foo_param", "foo"), Map.entry("bar_param", "bar")), - allParamsListener - ); - assertBusy(() -> verify(allParamsListener).onResponse(argThat(retrievedConfig -> { - assertThat(retrievedConfig.getFeatureExtractorBuilders(), hasSize(2)); - assertEquals(retrievedConfig, GOOD_MODEL_CONFIG.getInferenceConfig()); - return true; - }))); + { + ActionListener listener = mock(ActionListener.class); + learnToRankService.loadLearnToRankConfig( + TEMPLATED_GOOD_MODEL, + Map.ofEntries(Map.entry("foo_param", "foo"), Map.entry("bar_param", "bar")), + listener + ); + + verify(listener).onResponse(argThat(retrievedConfig -> { + assertThat(retrievedConfig.getFeatureExtractorBuilders(), hasSize(2)); + assertEquals(retrievedConfig, GOOD_MODEL_CONFIG.getInferenceConfig()); + return true; + })); + } } @Override diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index d8edea137330f..5e414a7f997d5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; @@ -72,14 +73,15 @@ public void testMustRewrite() { @Override protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException { return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class)) - && (args[0] instanceof InferModelAction); + && (args[0] instanceof CoordinatedInferenceAction); } @Override protected Object simulateMethod(Method method, Object[] args) { - InferModelAction.Request request = (InferModelAction.Request) args[1]; + CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1]; assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout()); assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType()); + assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType()); // Randomisation cannot be used here as {@code #doAssertLuceneQuery} // asserts that 2 rewritten queries are the same @@ -89,7 +91,7 @@ protected Object 
simulateMethod(Method method, Object[] args) { } var response = InferModelAction.Response.builder() - .setId(request.getId()) + .setId(request.getModelId()) .addInferenceResults(List.of(new TextExpansionResults("foo", tokens, randomBoolean()))) .build(); @SuppressWarnings("unchecked") // We matched the method above. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java index 8506be491f7e1..a44aa9404f4f9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.test.AbstractQueryVectorBuilderTestCase; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; @@ -34,13 +35,14 @@ protected List additionalPlugins() { @Override protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVectorBuilder builder) { - assertThat(request, instanceOf(InferModelAction.Request.class)); - InferModelAction.Request inferRequest = (InferModelAction.Request) request; - assertThat(inferRequest.getTextInput(), hasSize(1)); - assertEquals(builder.getModelText(), inferRequest.getTextInput().get(0)); - assertEquals(builder.getModelId(), inferRequest.getId()); + assertThat(request, instanceOf(CoordinatedInferenceAction.Request.class)); + CoordinatedInferenceAction.Request inferRequest = (CoordinatedInferenceAction.Request) request; + 
assertThat(inferRequest.getInputs(), hasSize(1)); + assertEquals(builder.getModelText(), inferRequest.getInputs().get(0)); + assertEquals(builder.getModelId(), inferRequest.getModelId()); assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, inferRequest.getInferenceTimeout()); assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, inferRequest.getPrefixType()); + assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, inferRequest.getRequestModelType()); } public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuilder builder) { diff --git a/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java b/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java index 955cf0396326b..406ea50315de0 100644 --- a/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java +++ b/x-pack/plugin/old-lucene-versions/src/main/java/org/elasticsearch/xpack/lucene/bwc/OldLuceneVersions.java @@ -17,7 +17,6 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDecider; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; @@ -78,15 +77,14 @@ public class OldLuceneVersions extends Plugin implements IndexStorePlugin, Clust @Override public Collection createComponents(PluginServices services) { - ClusterService clusterService = services.clusterService(); ThreadPool threadPool = services.threadPool(); - this.failShardsListener.set(new FailShardsOnInvalidLicenseClusterListener(getLicenseState(), clusterService.getRerouteService())); + this.failShardsListener.set(new FailShardsOnInvalidLicenseClusterListener(getLicenseState(), 
services.rerouteService())); if (DiscoveryNode.isMasterNode(services.environment().settings())) { // We periodically look through the indices and identify if there are any archive indices, // then marking the feature as used. We do this on each master node so that if one master fails, the // continue reporting usage state. - var usageTracker = new ArchiveUsageTracker(getLicenseState(), clusterService::state); + var usageTracker = new ArchiveUsageTracker(getLicenseState(), services.clusterService()::state); threadPool.scheduleWithFixedDelay(usageTracker, TimeValue.timeValueMinutes(15), threadPool.generic()); } return List.of(); diff --git a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/util/SpatialCoordinateTypesTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/util/SpatialCoordinateTypesTests.java index d4db20faf0050..67e72d530e2e0 100644 --- a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/util/SpatialCoordinateTypesTests.java +++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/util/SpatialCoordinateTypesTests.java @@ -27,6 +27,7 @@ public class SpatialCoordinateTypesTests extends ESTestCase { record TestTypeFunctions(Supplier randomPoint, Function error) {} + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102863") public void testEncoding() { for (var type : types.entrySet()) { for (int i = 0; i < 10; i++) { diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java index 45c38e52ad9c3..83a38a4d0b328 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java @@ -321,7 +321,7 @@ public Collection 
createComponents(PluginServices services) { final List components = new ArrayList<>(); this.repositoriesServiceSupplier = services.repositoriesServiceSupplier(); this.threadPool.set(threadPool); - this.failShardsListener.set(new FailShardsOnInvalidLicenseClusterListener(getLicenseState(), clusterService.getRerouteService())); + this.failShardsListener.set(new FailShardsOnInvalidLicenseClusterListener(getLicenseState(), services.rerouteService())); if (DiscoveryNode.canContainData(settings)) { final CacheService cacheService = new CacheService(settings, clusterService, threadPool, new PersistentCache(nodeEnvironment)); this.cacheService.set(cacheService); @@ -357,12 +357,12 @@ public Collection createComponents(PluginServices services) { threadPool.scheduleWithFixedDelay(usageTracker, TimeValue.timeValueMinutes(15), threadPool.generic()); } - this.allocator.set(new SearchableSnapshotAllocator(client, clusterService.getRerouteService(), frozenCacheInfoService)); + this.allocator.set(new SearchableSnapshotAllocator(client, services.rerouteService(), frozenCacheInfoService)); components.add(new FrozenCacheServiceSupplier(frozenCacheService.get())); components.add(new CacheServiceSupplier(cacheService.get())); if (DiscoveryNode.isMasterNode(settings)) { new SearchableSnapshotIndexMetadataUpgrader(clusterService, threadPool).initialize(); - clusterService.addListener(new RepositoryUuidWatcher(clusterService.getRerouteService())); + clusterService.addListener(new RepositoryUuidWatcher(services.rerouteService())); } return Collections.unmodifiableList(components); } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 11edc66977e6c..b9d005e695459 100644 --- 
a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -127,9 +127,13 @@ public class Constants { "cluster:admin/xpack/connector/get", "cluster:admin/xpack/connector/list", "cluster:admin/xpack/connector/put", + "cluster:admin/xpack/connector/update_pipeline", + "cluster:admin/xpack/connector/update_scheduling", + "cluster:admin/xpack/connector/update_filtering", "cluster:admin/xpack/connector/sync_job/post", "cluster:admin/xpack/connector/sync_job/delete", - "cluster:admin/xpack/connector/update_scheduling", + "cluster:admin/xpack/connector/sync_job/check_in", + "cluster:admin/xpack/connector/sync_job/cancel", "cluster:admin/xpack/deprecation/info", "cluster:admin/xpack/deprecation/nodes/info", "cluster:admin/xpack/enrich/delete", @@ -280,6 +284,7 @@ public class Constants { "cluster:admin/xpack/watcher/settings/update", "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", + "cluster:internal/xpack/ml/coordinatedinference", "cluster:internal/xpack/ml/datafeed/isolate", "cluster:internal/xpack/ml/datafeed/running_state", "cluster:internal/xpack/ml/inference/infer", diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java index 3f29944631d42..6773da137ac96 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/SecurityTests.java @@ -772,13 +772,14 @@ public void testSecurityRestHandlerInterceptorCanBeInstalled() throws IllegalAcc settingsModule.getClusterSettings(), settingsModule.getSettingsFilter(), threadPool, - Arrays.asList(security), + List.of(security), null, 
null, usageService, null, Tracer.NOOP, mock(ClusterService.class), + null, List.of(), RestExtension.allowAll() ); diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java index 3e0e578ded120..caf8ae0e3107b 100644 --- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java +++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java @@ -40,6 +40,7 @@ public class TransportDeleteShutdownNodeAction extends AcknowledgedTransportMasterNodeAction { private static final Logger logger = LogManager.getLogger(TransportDeleteShutdownNodeAction.class); + private final RerouteService rerouteService; private final MasterServiceTaskQueue taskQueue; private static boolean deleteShutdownNodeState(Map shutdownMetadata, Request request) { @@ -89,8 +90,7 @@ public ClusterState execute(BatchExecutionContext batchE taskContext.onFailure(e); continue; } - var reroute = clusterService.getRerouteService(); - taskContext.success(() -> ackAndReroute(request, taskContext.getTask().listener(), reroute)); + taskContext.success(() -> ackAndReroute(request, taskContext.getTask().listener(), rerouteService)); } if (changed == false) { return batchExecutionContext.initialState(); @@ -108,6 +108,7 @@ public ClusterState execute(BatchExecutionContext batchE public TransportDeleteShutdownNodeAction( TransportService transportService, ClusterService clusterService, + RerouteService rerouteService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver @@ -123,6 +124,7 @@ public TransportDeleteShutdownNodeAction( indexNameExpressionResolver, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.rerouteService = rerouteService; taskQueue = 
clusterService.createTaskQueue("delete-node-shutdown", Priority.URGENT, new DeleteShutdownNodeExecutor()); } diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java index 767c128030538..7946bb7e46627 100644 --- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java +++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java @@ -43,6 +43,7 @@ public class TransportPutShutdownNodeAction extends AcknowledgedTransportMasterNodeAction { private static final Logger logger = LogManager.getLogger(TransportPutShutdownNodeAction.class); + private final RerouteService rerouteService; private final MasterServiceTaskQueue taskQueue; private final PutShutdownNodeExecutor executor = new PutShutdownNodeExecutor(); @@ -137,8 +138,7 @@ public ClusterState execute(BatchExecutionContext batchExec taskContext.onFailure(e); continue; } - var reroute = clusterService.getRerouteService(); - taskContext.success(() -> ackAndMaybeReroute(request, taskContext.getTask().listener(), reroute)); + taskContext.success(() -> ackAndMaybeReroute(request, taskContext.getTask().listener(), rerouteService)); } if (changed == false) { return batchExecutionContext.initialState(); @@ -156,6 +156,7 @@ public ClusterState execute(BatchExecutionContext batchExec public TransportPutShutdownNodeAction( TransportService transportService, ClusterService clusterService, + RerouteService rerouteService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver @@ -171,6 +172,7 @@ public TransportPutShutdownNodeAction( indexNameExpressionResolver, EsExecutors.DIRECT_EXECUTOR_SERVICE ); + this.rerouteService = rerouteService; taskQueue = clusterService.createTaskQueue("put-shutdown", 
Priority.URGENT, new PutShutdownNodeExecutor()); } diff --git a/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java b/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java index cf28bf9922b24..82b1427fc8e4f 100644 --- a/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java +++ b/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.metadata.NodesShutdownMetadata; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.test.ESTestCase; @@ -57,6 +58,7 @@ public void init() { var threadPool = mock(ThreadPool.class); var transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(threadPool); clusterService = mock(ClusterService.class); + var rerouteService = mock(RerouteService.class); var actionFilters = mock(ActionFilters.class); var indexNameExpressionResolver = mock(IndexNameExpressionResolver.class); when(clusterService.createTaskQueue(any(), any(), Mockito.>any())).thenReturn( @@ -65,6 +67,7 @@ public void init() { action = new TransportDeleteShutdownNodeAction( transportService, clusterService, + rerouteService, threadPool, actionFilters, indexNameExpressionResolver diff --git a/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java b/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java index cbd51ceebc729..1ea85f4ef07cf 100644 --- 
a/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java +++ b/x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.cluster.ClusterStateTaskExecutor.TaskContext; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata.Type; +import org.elasticsearch.cluster.routing.RerouteService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterServiceTaskQueue; import org.elasticsearch.core.TimeValue; @@ -63,6 +64,7 @@ public void init() { var threadPool = mock(ThreadPool.class); var transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(threadPool); clusterService = mock(ClusterService.class); + var rerouteService = mock(RerouteService.class); var actionFilters = mock(ActionFilters.class); var indexNameExpressionResolver = mock(IndexNameExpressionResolver.class); when(clusterService.createTaskQueue(any(), any(), Mockito.>any())).thenReturn( @@ -71,6 +73,7 @@ public void init() { action = new TransportPutShutdownNodeAction( transportService, clusterService, + rerouteService, threadPool, actionFilters, indexNameExpressionResolver diff --git a/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java b/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java index 954b5ba024418..c2e3786a1afe7 100644 --- a/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java +++ b/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java @@ -277,6 +277,7 @@ public void testOperatorControllerFromJSONContent() throws IOException { ReservedClusterStateService 
controller = new ReservedClusterStateService( clusterService, + null, List.of(new ReservedClusterSettingsAction(clusterSettings), new ReservedRepositoryAction(repositoriesService)) ); @@ -347,6 +348,7 @@ public void testOperatorControllerFromJSONContent() throws IOException { controller = new ReservedClusterStateService( clusterService, + null, List.of( new ReservedClusterSettingsAction(clusterSettings), new ReservedSnapshotAction(), diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml index 1876d1a6d3881..e768a6b348959 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml @@ -1,8 +1,10 @@ --- "Coalesce and to_ip functions": - skip: - version: " - 8.11.99" - reason: "fixes in 8.12 or later" + version: all + reason: "AwaitsFix https://github.com/elastic/elasticsearch/issues/102871" + # version: " - 8.11.99" + # reason: "fixes in 8.12 or later" features: warnings - do: bulk: diff --git a/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java index 2b3724eeae8d0..f67d1e4c37b28 100644 --- a/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Strings; +import org.elasticsearch.core.UpdateForV9; import 
org.elasticsearch.upgrades.FullClusterRestartUpgradeStatus; import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus; import org.junit.Before; @@ -90,7 +91,10 @@ protected Settings restClientSettings() { } public void testDeploymentSurvivesRestart() throws Exception { - assumeTrue("NLP model deployments added in 8.0", getOldClusterVersion().onOrAfter(Version.V_8_0_0)); + @UpdateForV9 // upgrade will always be from v8, condition can be removed + var originalClusterAtLeastV8 = getOldClusterVersion().onOrAfter(Version.V_8_0_0); + // These tests assume the original cluster is v8 - testing for features on the _current_ cluster will break for NEW + assumeTrue("NLP model deployments added in 8.0", originalClusterAtLeastV8); String modelId = "trained-model-full-cluster-restart"; diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/AbstractUpgradeTestCase.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/AbstractUpgradeTestCase.java index 865ba0c07cfeb..128fd8b47722f 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/AbstractUpgradeTestCase.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/AbstractUpgradeTestCase.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Booleans; +import org.elasticsearch.core.UpdateForV9; import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.test.SecuritySettingsSourceField; import org.junit.Before; @@ -31,13 +32,20 @@ public abstract class AbstractUpgradeTestCase extends ESRestTestCase { ); protected static final String UPGRADE_FROM_VERSION = System.getProperty("tests.upgrade_from_version"); - protected static final boolean SKIP_ML_TESTS = Booleans.parseBoolean(System.getProperty("tests.ml.skip", "false")); - // TODO: replace with feature testing - @Deprecated + protected static 
boolean isOriginalCluster(String clusterVersion) { + return UPGRADE_FROM_VERSION.equals(clusterVersion); + } + + @Deprecated(forRemoval = true) + @UpdateForV9 + // Tests should be reworked to rely on features from the current cluster (old, mixed or upgraded). + // Version test against the original cluster will be removed protected static boolean isOriginalClusterVersionAtLeast(Version supportedVersion) { - return Version.fromString(UPGRADE_FROM_VERSION).onOrAfter(supportedVersion); + // Always assume non-semantic versions are OK: this method will be removed in V9, we are testing the pre-upgrade cluster version, + // and non-semantic versions are always V8+ + return parseLegacyVersion(UPGRADE_FROM_VERSION).map(x -> x.onOrAfter(supportedVersion)).orElse(true); } @Override diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/ApiKeyBackwardsCompatibilityIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/ApiKeyBackwardsCompatibilityIT.java index 850a94f7133e9..1a37f31bffe79 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/ApiKeyBackwardsCompatibilityIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/ApiKeyBackwardsCompatibilityIT.java @@ -10,7 +10,8 @@ import org.apache.http.HttpHost; import org.apache.http.client.methods.HttpGet; import org.elasticsearch.Build; -import org.elasticsearch.Version; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; @@ -19,6 +20,8 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.test.XContentTestUtils; import org.elasticsearch.test.rest.ObjectPath; +import org.elasticsearch.test.rest.RestTestLegacyFeatures; +import org.elasticsearch.transport.RemoteClusterPortSettings; import org.elasticsearch.xcontent.XContentType; import 
org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; @@ -46,15 +49,14 @@ public class ApiKeyBackwardsCompatibilityIT extends AbstractUpgradeTestCase { - public static final Version API_KEY_SUPPORT_REMOTE_INDICES_VERSION = Build.current().isSnapshot() ? Version.V_8_8_0 : Version.V_8_9_1; - private RestClient oldVersionClient = null; private RestClient newVersionClient = null; public void testCreatingAndUpdatingApiKeys() throws Exception { assumeTrue( - "The remote_indices for API Keys are not supported before version " + API_KEY_SUPPORT_REMOTE_INDICES_VERSION, - isOriginalClusterVersionAtLeast(API_KEY_SUPPORT_REMOTE_INDICES_VERSION) == false + "The remote_indices for API Keys are not supported before transport version " + + RemoteClusterPortSettings.TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY, + minimumTransportVersion().before(RemoteClusterPortSettings.TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY) ); switch (CLUSTER_TYPE) { case OLD -> { @@ -182,8 +184,8 @@ private Tuple createOrGrantApiKey(RestClient client, String role "name": "%s", "role_descriptors": %s }""", name, roles); - // Grant API did not exist before 7.7.0 - final boolean grantApiKey = randomBoolean() && isOriginalClusterVersionAtLeast(Version.V_7_7_0); + + final boolean grantApiKey = randomBoolean(); if (grantApiKey) { createApiKeyRequest = new Request("POST", "/_security/api_key/grant"); createApiKeyRequest.setJsonEntity(org.elasticsearch.common.Strings.format(""" @@ -220,16 +222,16 @@ private void updateOrBulkUpdateApiKey(String id, String roles) throws IOExceptio private boolean isUpdateApiSupported(RestClient client) { return switch (CLUSTER_TYPE) { - case OLD -> isOriginalClusterVersionAtLeast(Version.V_8_4_0); // Update API was introduced in 8.4.0. 
- case MIXED -> isOriginalClusterVersionAtLeast(Version.V_8_4_0) || client == newVersionClient; + case OLD -> clusterHasFeature(RestTestLegacyFeatures.SECURITY_UPDATE_API_KEY); // Update API was introduced in 8.4.0. + case MIXED -> clusterHasFeature(RestTestLegacyFeatures.SECURITY_UPDATE_API_KEY) || client == newVersionClient; case UPGRADED -> true; }; } private boolean isBulkUpdateApiSupported(RestClient client) { return switch (CLUSTER_TYPE) { - case OLD -> isOriginalClusterVersionAtLeast(Version.V_8_5_0); // Bulk update API was introduced in 8.5.0. - case MIXED -> isOriginalClusterVersionAtLeast(Version.V_8_5_0) || client == newVersionClient; + case OLD -> clusterHasFeature(RestTestLegacyFeatures.SECURITY_BULK_UPDATE_API_KEY); // Bulk update API was introduced in 8.5.0. + case MIXED -> clusterHasFeature(RestTestLegacyFeatures.SECURITY_BULK_UPDATE_API_KEY) || client == newVersionClient; case UPGRADED -> true; }; } @@ -304,10 +306,21 @@ private static String randomRoleDescriptors(boolean includeRemoteIndices) { } boolean nodeSupportApiKeyRemoteIndices(Map nodeDetails) { - // TODO[lor]: the method can be kept, but we need to replace version check with features checks - String versionString = (String) nodeDetails.get("version"); - Version version = Version.fromString(versionString.replace("-SNAPSHOT", "")); - return version.onOrAfter(API_KEY_SUPPORT_REMOTE_INDICES_VERSION); + String nodeVersionString = (String) nodeDetails.get("version"); + TransportVersion transportVersion = getTransportVersionWithFallback( + nodeVersionString, + nodeDetails.get("transport_version"), + () -> TransportVersions.ZERO + ); + + if (transportVersion.equals(TransportVersions.ZERO)) { + // In cases where we were not able to find a TransportVersion, a pre-8.8.0 node answered about a newer (upgraded) node. + // In that case, the node will be current (upgraded), and remote indices are supported for sure. 
+ var nodeIsCurrent = nodeVersionString.equals(Build.current().version()); + assertTrue(nodeIsCurrent); + return true; + } + return transportVersion.onOrAfter(RemoteClusterPortSettings.TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY); } private void createClientsByVersion() throws IOException { diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java index e1845e901447e..d935672e0a243 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.client.Response; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Strings; +import org.elasticsearch.core.UpdateForV9; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; @@ -98,7 +99,10 @@ public void removeLogging() throws IOException { } public void testTrainedModelDeployment() throws Exception { - assumeTrue("NLP model deployments added in 8.0", isOriginalClusterVersionAtLeast(Version.V_8_0_0)); + @UpdateForV9 // upgrade will always be from v8, condition can be removed + var originalClusterAtLeastV8 = isOriginalClusterVersionAtLeast(Version.V_8_0_0); + // These tests assume the original cluster is v8 - testing for features on the _current_ cluster will break for NEW + assumeTrue("NLP model deployments added in 8.0", originalClusterAtLeastV8); final String modelId = "upgrade-deployment-test"; @@ -134,7 +138,10 @@ public void testTrainedModelDeployment() throws Exception { } public void testTrainedModelDeploymentStopOnMixedCluster() throws Exception { - assumeTrue("NLP model deployments added in 8.0", isOriginalClusterVersionAtLeast(Version.V_8_0_0)); + @UpdateForV9 // upgrade will always be 
from v8, condition can be removed + var originalClusterAtLeastV8 = isOriginalClusterVersionAtLeast(Version.V_8_0_0); + // These tests assume the original cluster is v8 - testing for features on the _current_ cluster will break for NEW + assumeTrue("NLP model deployments added in 8.0", originalClusterAtLeastV8); final String modelId = "upgrade-deployment-test-stop-mixed-cluster"; diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java index f1a72663aaf82..657a51dfe1b95 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MlAssignmentPlannerUpgradeIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Strings; +import org.elasticsearch.core.UpdateForV9; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -67,7 +68,10 @@ public class MlAssignmentPlannerUpgradeIT extends AbstractUpgradeTestCase { @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101926") public void testMlAssignmentPlannerUpgrade() throws Exception { - assumeTrue("NLP model deployments added in 8.0", isOriginalClusterVersionAtLeast(Version.V_8_0_0)); + @UpdateForV9 // upgrade will always be from v8, condition can be removed + var originalClusterAtLeastV8 = isOriginalClusterVersionAtLeast(Version.V_8_0_0); + // These tests assume the original cluster is v8 - testing for features on the _current_ cluster will break for NEW + assumeTrue("NLP model deployments added in 8.0", originalClusterAtLeastV8); assumeFalse("This test deploys multiple models which cannot be accommodated on a single processor", IS_SINGLE_PROCESSOR_TEST); 
logger.info("Starting testMlAssignmentPlannerUpgrade, model size {}", RAW_MODEL_SIZE); diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SearchableSnapshotsRollingUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SearchableSnapshotsRollingUpgradeIT.java index 0f25592493a1c..0c9827f649170 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SearchableSnapshotsRollingUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SearchableSnapshotsRollingUpgradeIT.java @@ -11,7 +11,6 @@ import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPut; -import org.elasticsearch.Version; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.common.Strings; @@ -160,13 +159,13 @@ private void executeBlobCacheCreationTestCase(Storage storage, long numberOfDocs final var newVersionNodes = nodesIdsAndVersions.entrySet() .stream() - .filter(node -> UPGRADE_FROM_VERSION.equals(node.getValue()) == false) + .filter(node -> isOriginalCluster(node.getValue()) == false) .map(Map.Entry::getKey) .collect(Collectors.toSet()); final var originalVersionNodes = nodesIdsAndVersions.entrySet() .stream() - .filter(node -> UPGRADE_FROM_VERSION.equals(node.getValue())) + .filter(node -> isOriginalCluster(node.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toSet()); @@ -288,27 +287,22 @@ private void executeBlobCacheCreationTestCase(Storage storage, long numberOfDocs assertHitCount(index, equalTo(numberOfDocs * 2L)); deleteIndex(index); - if (isOriginalClusterVersionAtLeast(Version.V_7_13_0)) { - final Request request = new Request( - "GET", - "/.snapshot-blob-cache/_settings/index.routing.allocation.include._tier_preference" - ); - request.setOptions( - expectWarnings( - "this request accesses system indices: [.snapshot-blob-cache], but in 
a future major " - + "version, direct access to system indices will be prevented by default" - ) - ); - request.addParameter("flat_settings", "true"); + final Request request = new Request("GET", "/.snapshot-blob-cache/_settings/index.routing.allocation.include._tier_preference"); + request.setOptions( + expectWarnings( + "this request accesses system indices: [.snapshot-blob-cache], but in a future major " + + "version, direct access to system indices will be prevented by default" + ) + ); + request.addParameter("flat_settings", "true"); - final Map snapshotBlobCacheSettings = entityAsMap(adminClient().performRequest(request)); - assertThat(snapshotBlobCacheSettings, notNullValue()); - final String tierPreference = (String) extractValue( - ".snapshot-blob-cache.settings.index.routing.allocation.include._tier_preference", - snapshotBlobCacheSettings - ); - assertThat(tierPreference, equalTo("data_content,data_hot")); - } + final Map snapshotBlobCacheSettings = entityAsMap(adminClient().performRequest(request)); + assertThat(snapshotBlobCacheSettings, notNullValue()); + final String tierPreference = (String) extractValue( + ".snapshot-blob-cache.settings.index.routing.allocation.include._tier_preference", + snapshotBlobCacheSettings + ); + assertThat(tierPreference, equalTo("data_content,data_hot")); } else if (CLUSTER_TYPE.equals(ClusterType.UPGRADED)) { for (String snapshot : snapshots) { diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TransformSurvivesUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TransformSurvivesUpgradeIT.java index c24665d812db6..78ee66fa4d327 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TransformSurvivesUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TransformSurvivesUpgradeIT.java @@ -235,7 +235,7 @@ private void verifyContinuousTransformHandlesData(long expectedLastCheckpoint) t private void 
verifyUpgradeFailsIfMixedCluster() { // upgrade tests by design are also executed with the same version, this check must be skipped in this case, see gh#39102. - if (UPGRADE_FROM_VERSION.equals(Build.current().version())) { + if (isOriginalCluster(Build.current().version())) { return; } final Request upgradeTransformRequest = new Request("POST", getTransformEndpoint() + "_upgrade");