Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] adds new default_field_map field to trained models #53294

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAPPINGS = new ParseField("default_field_mappings");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
Expand All @@ -76,6 +77,7 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMappings, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAPPINGS);
}

public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
Expand All @@ -95,6 +97,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
private final Long estimatedHeapMemory;
private final Long estimatedOperations;
private final String licenseLevel;
private final Map<String, String> defaultFieldMappings;

TrainedModelConfig(String modelId,
String createdBy,
Expand All @@ -108,7 +111,8 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations,
String licenseLevel) {
String licenseLevel,
Map<String, String> defaultFieldMappings) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
Expand All @@ -122,6 +126,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
this.licenseLevel = licenseLevel;
this.defaultFieldMappings = defaultFieldMappings == null ? null : Collections.unmodifiableMap(defaultFieldMappings);
}

public String getModelId() {
Expand Down Expand Up @@ -180,6 +185,10 @@ public String getLicenseLevel() {
return licenseLevel;
}

/**
 * Returns the default field mappings configured for this trained model.
 *
 * @return the mappings held by this config (wrapped as an unmodifiable map by the
 *         constructor when non-null), or {@code null} when none were supplied
 */
public Map<String, String> getDefaultFieldMappings() {
    return this.defaultFieldMappings;
}

/**
 * Static factory for a new {@link Builder}.
 *
 * @return a freshly constructed {@link Builder} instance
 */
public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -226,6 +235,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (licenseLevel != null) {
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel);
}
if (defaultFieldMappings != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (defaultFieldMappings != null) {
if (defaultFieldMappings != null && defaultFieldMappings.isEmpty() == false) {

I prefer not to write empty collections.

builder.field(DEFAULT_FIELD_MAPPINGS.getPreferredName(), defaultFieldMappings);
}
builder.endObject();
return builder;
}
Expand All @@ -252,6 +264,7 @@ public boolean equals(Object o) {
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMappings, that.defaultFieldMappings) &&
Objects.equals(metadata, that.metadata);
}

Expand All @@ -269,7 +282,8 @@ public int hashCode() {
estimatedOperations,
metadata,
licenseLevel,
input);
input,
defaultFieldMappings);
}


Expand All @@ -288,6 +302,7 @@ public static class Builder {
private Long estimatedHeapMemory;
private Long estimatedOperations;
private String licenseLevel;
private Map<String, String> defaultFieldMappings;

public Builder setModelId(String modelId) {
this.modelId = modelId;
Expand Down Expand Up @@ -367,6 +382,11 @@ private Builder setLicenseLevel(String licenseLevel) {
return this;
}

/**
 * Sets the default field mappings to carry into the built config.
 * <p>
 * The supplied map is stored by reference (no copy is taken here), and {@code null} is accepted.
 *
 * @param defaultFieldMappings string-to-string field mappings, or {@code null} for none
 * @return this builder, to allow call chaining
 */
public Builder setDefaultFieldMappings(Map<String, String> defaultFieldMappings) {
this.defaultFieldMappings = defaultFieldMappings;
return this;
}

public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
Expand All @@ -381,7 +401,8 @@ public TrainedModelConfig build() {
input,
estimatedHeapMemory,
estimatedOperations,
licenseLevel);
licenseLevel,
defaultFieldMappings);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -52,7 +53,11 @@ public static TrainedModelConfig createTestTrainedModelConfig() {
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomFrom("platinum", "basic"));
randomBoolean() ? null : randomFrom("platinum", "basic"),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 neat I never thought of building a random map this way

}

@Override
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ingested in the pipeline.
| Name | Required | Default | Description
| `model_id` | yes | - | (String) The ID of the model to load and infer against.
| `target_field` | no | `ml.inference.<processor_tag>` | (String) Field added to incoming documents to contain results objects.
| `field_mappings` | yes | - | (Object) Maps the document field names to the known field names of the model.
| `field_mappings` | yes | - | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
| `inference_config` | yes | - | (Object) Contains the inference type and its options. There are two types: <<inference-processor-regression-opt,`regression`>> and <<inference-processor-classification-opt,`classification`>>.
include::common-options.asciidoc[]
|======
Expand Down
7 changes: 7 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,13 @@ The estimated number of operations to use the trained model.
`license_level`:::
(string)
The license level of the trained model.

`default_field_mappings`:::
(object)
A string-to-string mapping that contains the default field mappings to use
when inferring against the model. Any field mapping described in the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a line here about multi fields which is the primary reason for this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sure!

inference configuration takes precedence.

end::trained-model-configs[]

tag::training-percent[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -60,6 +61,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAPPINGS = new ParseField("default_field_mappings");

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
Expand Down Expand Up @@ -90,6 +92,7 @@ private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boole
DEFINITION);
parser.declareString(TrainedModelConfig.Builder::setLazyDefinition, COMPRESSED_DEFINITION);
parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
parser.declareObject(TrainedModelConfig.Builder::setDefaultFieldMappings, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAPPINGS);
return parser;
}

Expand All @@ -108,6 +111,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
private final long estimatedHeapMemory;
private final long estimatedOperations;
private final License.OperationMode licenseLevel;
private final Map<String, String> defaultFieldMappings;

private final LazyModelDefinition definition;

Expand All @@ -122,7 +126,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations,
String licenseLevel) {
String licenseLevel,
Map<String, String> defaultFieldMappings) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
Expand All @@ -142,6 +147,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
}
this.estimatedOperations = estimatedOperations;
this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL));
this.defaultFieldMappings = defaultFieldMappings == null ? null : Collections.unmodifiableMap(defaultFieldMappings);
}

public TrainedModelConfig(StreamInput in) throws IOException {
Expand All @@ -157,6 +163,13 @@ public TrainedModelConfig(StreamInput in) throws IOException {
estimatedHeapMemory = in.readVLong();
estimatedOperations = in.readVLong();
licenseLevel = License.OperationMode.parse(in.readString());
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.defaultFieldMappings = in.readBoolean() ?
Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)) :
null;
} else {
this.defaultFieldMappings = null;
}
}

public String getModelId() {
Expand Down Expand Up @@ -187,6 +200,10 @@ public Map<String, Object> getMetadata() {
return metadata;
}

/**
 * Returns the default field mappings configured for this trained model.
 *
 * @return the mappings held by this config (an unmodifiable map when present),
 *         or {@code null} when none were supplied or read from the stream
 */
public Map<String, String> getDefaultFieldMappings() {
    return this.defaultFieldMappings;
}

@Nullable
public String getCompressedDefinition() throws IOException {
if (definition == null) {
Expand Down Expand Up @@ -249,6 +266,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(estimatedHeapMemory);
out.writeVLong(estimatedOperations);
out.writeString(licenseLevel.description());
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
if (defaultFieldMappings != null) {
out.writeBoolean(true);
out.writeMap(defaultFieldMappings, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand Down Expand Up @@ -283,6 +308,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
if (defaultFieldMappings != null) {
builder.field(DEFAULT_FIELD_MAPPINGS.getPreferredName(), defaultFieldMappings);
}
builder.endObject();
return builder;
}
Expand All @@ -308,6 +336,7 @@ public boolean equals(Object o) {
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMappings, that.defaultFieldMappings) &&
Objects.equals(metadata, that.metadata);
}

Expand All @@ -324,7 +353,8 @@ public int hashCode() {
estimatedHeapMemory,
estimatedOperations,
input,
licenseLevel);
licenseLevel,
defaultFieldMappings);
}

public static class Builder {
Expand All @@ -341,6 +371,7 @@ public static class Builder {
private Long estimatedOperations;
private LazyModelDefinition definition;
private String licenseLevel;
private Map<String, String> defaultFieldMappings;

public Builder() {}

Expand All @@ -357,6 +388,7 @@ public Builder(TrainedModelConfig config) {
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
this.defaultFieldMappings = config.defaultFieldMappings == null ? null : new HashMap<>(config.defaultFieldMappings);
}

public Builder setModelId(String modelId) {
Expand Down Expand Up @@ -475,6 +507,11 @@ public Builder setLicenseLevel(String licenseLevel) {
return this;
}

/**
 * Sets the default field mappings to carry into the built config.
 * <p>
 * A defensive copy of the supplied map is stored. The built config only wraps its map in an
 * unmodifiable <em>view</em>, so keeping the caller's reference would let later changes to
 * the caller's map leak into an already-built config. Copying here matches what the
 * copy-constructor of this builder already does.
 *
 * @param defaultFieldMappings string-to-string field mappings, or {@code null} for none
 * @return this builder, to allow call chaining
 */
public Builder setDefaultFieldMappings(Map<String, String> defaultFieldMappings) {
    // null stays null so "not set" remains distinguishable from "set but empty"
    this.defaultFieldMappings = defaultFieldMappings == null ? null : new HashMap<>(defaultFieldMappings);
    return this;
}

/**
 * Validates this builder by delegating to {@code validate(boolean)} with {@code false}.
 * NOTE(review): the meaning of the boolean flag is not visible in this hunk; confirm
 * against the overload before relying on it.
 *
 * @return the result of {@code validate(false)}
 */
public Builder validate() {
return validate(false);
}
Expand Down Expand Up @@ -567,7 +604,8 @@ public TrainedModelConfig build() {
input,
estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
estimatedOperations == null ? 0 : estimatedOperations,
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel,
defaultFieldMappings);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
},
"total_definition_length": {
"type": "long"
},
"default_field_mappings": {
"enabled": false
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
Expand Down Expand Up @@ -137,7 +139,11 @@ public void testToXContentWithParams() throws IOException {
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong(),
"platinum");
"platinum",
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));

BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
Expand Down Expand Up @@ -172,7 +178,11 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong(),
"platinum");
"platinum",
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));

BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
Expand Down
Loading