Skip to content

Commit

Permalink
[ML] adds new for_export flag to GET _ml/inference API (#57351)
Browse files Browse the repository at this point in the history
Adds a new boolean flag, `for_export`, to the `GET _ml/inference/<model_id>` API.

This flag is useful for moving models between clusters.
  • Loading branch information
benwtrent authored May 29, 2020
1 parent b483246 commit 251b170
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest)
if (getTrainedModelsRequest.getTags() != null) {
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
}
if (getTrainedModelsRequest.getForExport() != null) {
params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ public class GetTrainedModelsRequest implements Validatable {

public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String FOR_EXPORT = "for_export";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TAGS = "tags";

private final List<String> ids;
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Boolean decompressDefinition;
private Boolean forExport;
private PageParams pageParams;
private List<String> tags;

Expand Down Expand Up @@ -137,6 +139,23 @@ public GetTrainedModelsRequest setTags(String... tags) {
return setTags(Arrays.asList(tags));
}

/**
 * Returns the {@code for_export} flag currently held by this request.
 *
 * @return the value last supplied through {@link #setForExport(Boolean)}; may be {@code null} if unset
 */
public Boolean getForExport() {
    return this.forExport;
}

/**
* Sets whether the model should be returned in a form suitable for export.
*
* Setting this flag to {@code true} removes certain fields from the model configuration on retrieval.
*
* This is useful when getting the model and wanting to put it in another cluster.
*
* Default value is false.
* @param forExport Boolean value indicating if certain fields should be removed from the model on GET
* @return this request instance, so that setter calls can be chained
*/
public GetTrainedModelsRequest setForExport(Boolean forExport) {
this.forExport = forExport;
return this;
}

@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {
Expand All @@ -155,11 +174,12 @@ public boolean equals(Object o) {
&& Objects.equals(allowNoMatch, other.allowNoMatch)
&& Objects.equals(decompressDefinition, other.decompressDefinition)
&& Objects.equals(includeDefinition, other.includeDefinition)
&& Objects.equals(forExport, other.forExport)
&& Objects.equals(pageParams, other.pageParams);
}

@Override
public int hashCode() {
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3611,7 +3611,8 @@ public void testGetTrainedModels() throws Exception {
.setIncludeDefinition(false) // <3>
.setDecompressDefinition(false) // <4>
.setAllowNoMatch(true) // <5>
.setTags("regression"); // <6>
.setTags("regression") // <6>
.setForExport(false); // <7>
// end::get-trained-models-request
request.setTags((List<String>)null);

Expand Down
3 changes: 3 additions & 0 deletions docs/java-rest/high-level/ml/get-trained-models.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request]
<6> An optional list of tags used to narrow the model search. A Trained Model
can have many tags or none. The trained models in the response will
contain all the provided tags.
<7> Optional boolean value indicating if certain fields should be removed on
retrieval. This is useful for getting the trained model in a format that
can then be put into another cluster.

include::../execution.asciidoc[]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size]
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]

`for_export`::
(Optional, boolean)
Indicates whether certain fields should be removed from the model configuration
on retrieval, so that the retrieved model is in a format that can be added to
another cluster. Default is false.

[role="child_attributes"]
[[ml-get-inference-results]]
==== {api-response-body-title}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_config";
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String FOR_EXPORT = "for_export";

private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";

Expand Down Expand Up @@ -304,13 +305,22 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(CREATED_BY.getPreferredName(), createdBy);
builder.field(VERSION.getPreferredName(), version.toString());
// If the model is to be exported for future import to another cluster, these fields are irrelevant.
if (params.paramAsBoolean(FOR_EXPORT, false) == false) {
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(CREATED_BY.getPreferredName(), createdBy);
builder.field(VERSION.getPreferredName(), version.toString());
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
builder.humanReadableField(
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
}
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
Expand All @@ -327,12 +337,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.field(INPUT.getPreferredName(), input);
builder.humanReadableField(
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -187,6 +188,43 @@ public void testGetPrePackagedModels() throws IOException {
assertThat(response, containsString("\"definition\""));
}

/**
 * Round-trips a regression model: stores it, retrieves it with {@code for_export=true},
 * PUTs the retrieved configuration under a second id and finally checks that a wildcard
 * GET reports both models.
 */
@SuppressWarnings("unchecked")
public void testExportImportModel() throws IOException {
    final String exportedModelId = "regression_model_to_export";
    final String importedModelId = "regression_model_to_import";
    final String inferencePath = MachineLearning.BASE_PATH + "inference/";
    putRegressionModel(exportedModelId);

    // A plain GET must return exactly the model that was just stored.
    Response plainGet = client().performRequest(new Request("GET", inferencePath + exportedModelId));
    assertThat(plainGet.getStatusLine().getStatusCode(), equalTo(200));
    String plainBody = EntityUtils.toString(plainGet.getEntity());
    assertThat(plainBody, containsString("\"model_id\":\"regression_model_to_export\""));
    assertThat(plainBody, containsString("\"count\":1"));

    // Retrieve the same model again in its export form, with the compressed definition included.
    Request exportRequest = new Request(
        "GET",
        inferencePath + exportedModelId + "?include_model_definition=true&decompress_definition=false&for_export=true");
    Response exportGet = client().performRequest(exportRequest);
    assertThat(exportGet.getStatusLine().getStatusCode(), equalTo(200));

    List<Map<String, Object>> exportedConfigs =
        (List<Map<String, Object>>) entityAsMap(exportGet).get("trained_model_configs");
    Map<String, Object> exportedConfig = exportedConfigs.get(0);

    // The exported configuration is sent back unchanged as the body of a PUT under the new id.
    try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
        builder.map(exportedConfig);
        Request putRequest = new Request("PUT", "_ml/inference/" + importedModelId);
        putRequest.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
        Response putResponse = client().performRequest(putRequest);
        assertThat(putResponse.getStatusLine().getStatusCode(), equalTo(200));
    }

    // Both the original and the imported model must now match the wildcard.
    Response wildcardGet = client().performRequest(new Request("GET", inferencePath + "regression*"));
    assertThat(wildcardGet.getStatusLine().getStatusCode(), equalTo(200));
    String wildcardBody = EntityUtils.toString(wildcardGet.getEntity());
    assertThat(wildcardBody, containsString("\"model_id\":\"regression_model_to_export\""));
    assertThat(wildcardBody, containsString("\"model_id\":\"regression_model_to_import\""));
    assertThat(wildcardBody, containsString("\"count\":2"));
}

private void putRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient

@Override
protected Set<String> responseParams() {
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
return Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT);
}

private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@
"required":false,
"type":"list",
"description":"A comma-separated list of tags that the model must have."
},
"for_export": {
"required": false,
"type": "boolean",
"default": false,
"description": "Omits fields that are illegal to set on model PUT"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -818,3 +818,24 @@ setup:
}
}
}
---
# Requests a single model with for_export enabled and checks which fields the response contains.
"Test for_export flag":
- do:
ml.get_trained_models:
model_id: "a-regression-model-1"
for_export: true
include_model_definition: true
decompress_definition: false

# Fields that must still be present in the exported configuration.
- match: { trained_model_configs.0.description: "empty model for tests" }
- is_true: trained_model_configs.0.compressed_definition
- is_true: trained_model_configs.0.input
- is_true: trained_model_configs.0.inference_config
- is_true: trained_model_configs.0.tags
# Fields that must be absent when for_export is true.
- is_false: trained_model_configs.0.model_id
- is_false: trained_model_configs.0.created_by
- is_false: trained_model_configs.0.version
- is_false: trained_model_configs.0.create_time
- is_false: trained_model_configs.0.estimated_heap_memory_usage
- is_false: trained_model_configs.0.estimated_operations
- is_false: trained_model_configs.0.license_level

0 comments on commit 251b170

Please sign in to comment.