Skip to content

Commit

Permalink
[ML] Add platform_architecture to package config
Browse files Browse the repository at this point in the history
Adds the new platform_architecture field from elastic#99584
to the package config used when downloading Elastic
models from GCS.
  • Loading branch information
droberts195 committed Oct 3, 2023
1 parent 1c86da0 commit d849398
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 6 deletions.
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_MODEL_SECRETS_ADDED = def(8_509_00_0);
public static final TransportVersion NODE_INFO_REQUEST_SIMPLIFIED = def(8_510_00_0);
public static final TransportVersion NESTED_KNN_VECTOR_QUERY_V = def(8_511_00_0);
public static final TransportVersion ML_PACKAGE_LOADER_PLATFORM_ADDED = def(8_512_00_0);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -38,6 +39,7 @@ public class ModelPackageConfig implements ToXContentObject, Writeable {
public static final ParseField SIZE = new ParseField("size");
public static final ParseField CHECKSUM_SHA256 = new ParseField("sha256");
public static final ParseField VOCABULARY_FILE = new ParseField("vocabulary_file");
public static final ParseField PLATFORM_ARCHITECTURE = new ParseField("platform_architecture");

private static final ConstructingObjectParser<ModelPackageConfig, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<ModelPackageConfig, Void> STRICT_PARSER = createParser(false);
Expand Down Expand Up @@ -66,7 +68,8 @@ private static ConstructingObjectParser<ModelPackageConfig, Void> createParser(b
metadata,
(String) a[9], // model_type
tags,
(String) a[11] // vocabulary file
(String) a[11], // vocabulary file
(String) a[12] // platform architecture
);
}
);
Expand All @@ -91,6 +94,7 @@ private static ConstructingObjectParser<ModelPackageConfig, Void> createParser(b
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), TrainedModelConfig.MODEL_TYPE);
parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), TrainedModelConfig.TAGS);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), VOCABULARY_FILE);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PLATFORM_ARCHITECTURE);

return parser;
}
Expand All @@ -117,6 +121,7 @@ public static ModelPackageConfig fromXContentLenient(XContentParser parser) thro
private final String modelType;
private final List<String> tags;
private final String vocabularyFile;
private final String platformArchitecture;

public ModelPackageConfig(
String packagedModelId,
Expand All @@ -130,7 +135,8 @@ public ModelPackageConfig(
Map<String, Object> metadata,
String modelType,
List<String> tags,
String vocabularyFile
String vocabularyFile,
String platformArchitecture
) {
this.packagedModelId = ExceptionsHelper.requireNonNull(packagedModelId, PACKAGED_MODEL_ID);
this.modelRepository = modelRepository;
Expand All @@ -147,6 +153,7 @@ public ModelPackageConfig(
this.modelType = modelType;
this.tags = tags == null ? Collections.emptyList() : Collections.unmodifiableList(tags);
this.vocabularyFile = vocabularyFile;
this.platformArchitecture = platformArchitecture;
}

public ModelPackageConfig(StreamInput in) throws IOException {
Expand All @@ -162,6 +169,11 @@ public ModelPackageConfig(StreamInput in) throws IOException {
this.modelType = in.readOptionalString();
this.tags = in.readOptionalCollectionAsList(StreamInput::readString);
this.vocabularyFile = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED)) {
this.platformArchitecture = in.readOptionalString();
} else {
platformArchitecture = null;
}
}

public String getPackagedModelId() {
Expand Down Expand Up @@ -212,6 +224,10 @@ public String getVocabularyFile() {
return vocabularyFile;
}

public String getPlatformArchitecture() {
return platformArchitecture;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down Expand Up @@ -249,6 +265,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (Strings.isNullOrEmpty(vocabularyFile) == false) {
builder.field(VOCABULARY_FILE.getPreferredName(), vocabularyFile);
}
if (Strings.isNullOrEmpty(platformArchitecture) == false) {
builder.field(PLATFORM_ARCHITECTURE.getPreferredName(), platformArchitecture);
}

builder.endObject();
return builder;
Expand All @@ -268,6 +287,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(modelType);
out.writeOptionalStringCollection(tags);
out.writeOptionalString(vocabularyFile);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED)) {
out.writeOptionalString(platformArchitecture);
}
}

@Override
Expand All @@ -290,7 +312,8 @@ public boolean equals(Object o) {
&& Objects.equals(metadata, that.metadata)
&& Objects.equals(modelType, that.modelType)
&& Objects.equals(tags, that.tags)
&& Objects.equals(vocabularyFile, that.vocabularyFile);
&& Objects.equals(vocabularyFile, that.vocabularyFile)
&& Objects.equals(platformArchitecture, that.platformArchitecture);
}

@Override
Expand All @@ -307,7 +330,8 @@ public int hashCode() {
metadata,
modelType,
tags,
vocabularyFile
vocabularyFile,
platformArchitecture
);
}

Expand All @@ -330,6 +354,7 @@ public static class Builder {
private String modelType;
private List<String> tags;
private String vocabularyFile;
private String platformArchitecture;

public Builder(ModelPackageConfig modelPackageConfig) {
this.packagedModelId = modelPackageConfig.packagedModelId;
Expand All @@ -344,6 +369,7 @@ public Builder(ModelPackageConfig modelPackageConfig) {
this.modelType = modelPackageConfig.modelType;
this.tags = modelPackageConfig.tags;
this.vocabularyFile = modelPackageConfig.vocabularyFile;
this.platformArchitecture = modelPackageConfig.platformArchitecture;
}

public Builder setPackedModelId(String packagedModelId) {
Expand Down Expand Up @@ -406,6 +432,11 @@ public Builder setVocabularyFile(String vocabularyFile) {
return this;
}

public Builder setPlatformArchitecture(String platformArchitecture) {
this.platformArchitecture = platformArchitecture;
return this;
}

/**
* Reset all fields which are only part of the package metadata, but not be part
* of the config.
Expand Down Expand Up @@ -441,7 +472,8 @@ public ModelPackageConfig build() {
metadata,
modelType,
tags,
vocabularyFile
vocabularyFile,
platformArchitecture
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentHelper;
Expand Down Expand Up @@ -43,12 +44,13 @@ public static ModelPackageConfig randomModulePackageConfig() {
randomBoolean() ? Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null,
randomFrom(TrainedModelType.values()).toString(),
randomBoolean() ? Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false)) : null,
randomBoolean() ? randomAlphaOfLength(10) : null,
randomBoolean() ? randomAlphaOfLength(10) : null
);
}

public static ModelPackageConfig mutateModelPackageConfig(ModelPackageConfig instance) {
switch (between(0, 11)) {
switch (between(0, 12)) {
case 0:
return new ModelPackageConfig.Builder(instance).setPackedModelId(randomAlphaOfLength(15)).build();
case 1:
Expand Down Expand Up @@ -83,6 +85,8 @@ public static ModelPackageConfig mutateModelPackageConfig(ModelPackageConfig ins
).build();
case 11:
return new ModelPackageConfig.Builder(instance).setVocabularyFile(randomAlphaOfLength(15)).build();
case 12:
return new ModelPackageConfig.Builder(instance).setPlatformArchitecture(randomAlphaOfLength(15)).build();
default:
throw new AssertionError("Illegal randomisation branch");
}
Expand Down Expand Up @@ -110,6 +114,9 @@ protected ModelPackageConfig mutateInstance(ModelPackageConfig instance) {

@Override
protected ModelPackageConfig mutateInstanceForVersion(ModelPackageConfig instance, TransportVersion version) {
if (version.before(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED)) {
return new ModelPackageConfig.Builder(instance).setPlatformArchitecture(null).build();
}
return instance;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ static void setTrainedModelConfigFieldsFromPackagedModel(
) throws IOException {
trainedModelConfig.setDescription(resolvedModelPackageConfig.getDescription());
trainedModelConfig.setModelType(TrainedModelType.fromString(resolvedModelPackageConfig.getModelType()));
trainedModelConfig.setPlatformArchitecture(resolvedModelPackageConfig.getPlatformArchitecture());
trainedModelConfig.setMetadata(resolvedModelPackageConfig.getMetadata());
trainedModelConfig.setInferenceConfig(
parseInferenceConfigFromModelPackage(
Expand Down

0 comments on commit d849398

Please sign in to comment.