Skip to content

Commit

Permalink
[ML] Introduce randomize_seed setting for regression and classificati…
Browse files Browse the repository at this point in the history
…on (#49990)

This adds a new `randomize_seed` for regression and classification.
When not explicitly set, the seed is randomly generated. One can
reuse the seed in a similar job in order to ensure the same docs
are picked for training.
  • Loading branch information
dimitris-athanasiou authored Dec 10, 2019
1 parent a6351d6 commit 269425b
Show file tree
Hide file tree
Showing 24 changed files with 460 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public static Builder builder(String dependentVariable) {
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");

private static final ConstructingObjectParser<Classification, Void> PARSER =
new ConstructingObjectParser<>(
Expand All @@ -63,7 +64,8 @@ public static Builder builder(String dependentVariable) {
(Double) a[5],
(String) a[6],
(Double) a[7],
(Integer) a[8]));
(Integer) a[8],
(Long) a[9]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
Expand All @@ -75,6 +77,7 @@ public static Builder builder(String dependentVariable) {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
}

private final String dependentVariable;
Expand All @@ -86,10 +89,11 @@ public static Builder builder(String dependentVariable) {
private final String predictionFieldName;
private final Double trainingPercent;
private final Integer numTopClasses;
private final Long randomizeSeed;

private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
Expand All @@ -99,6 +103,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.numTopClasses = numTopClasses;
this.randomizeSeed = randomizeSeed;
}

@Override
Expand Down Expand Up @@ -138,6 +143,10 @@ public Double getTrainingPercent() {
return trainingPercent;
}

public Long getRandomizeSeed() {
return randomizeSeed;
}

public Integer getNumTopClasses() {
return numTopClasses;
}
Expand Down Expand Up @@ -167,6 +176,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
Expand All @@ -177,7 +189,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent, numTopClasses);
trainingPercent, randomizeSeed, numTopClasses);
}

@Override
Expand All @@ -193,6 +205,7 @@ public boolean equals(Object o) {
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(numTopClasses, that.numTopClasses);
}

Expand All @@ -211,6 +224,7 @@ public static class Builder {
private String predictionFieldName;
private Double trainingPercent;
private Integer numTopClasses;
private Long randomizeSeed;

private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
Expand Down Expand Up @@ -251,14 +265,19 @@ public Builder setTrainingPercent(Double trainingPercent) {
return this;
}

public Builder setRandomizeSeed(Long randomizeSeed) {
this.randomizeSeed = randomizeSeed;
return this;
}

public Builder setNumTopClasses(Integer numTopClasses) {
this.numTopClasses = numTopClasses;
return this;
}

public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent, numTopClasses);
trainingPercent, numTopClasses, randomizeSeed);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) {
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");

private static final ConstructingObjectParser<Regression, Void> PARSER =
new ConstructingObjectParser<>(
Expand All @@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) {
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));
(Double) a[7],
(Long) a[8]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
Expand All @@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
}

private final String dependentVariable;
Expand All @@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) {
private final Double featureBagFraction;
private final String predictionFieldName;
private final Double trainingPercent;
private final Long randomizeSeed;

private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
Expand All @@ -94,6 +98,7 @@ private Regression(String dependentVariable, @Nullable Double lambda, @Nullable
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
this.trainingPercent = trainingPercent;
this.randomizeSeed = randomizeSeed;
}

@Override
Expand Down Expand Up @@ -133,6 +138,10 @@ public Double getTrainingPercent() {
return trainingPercent;
}

public Long getRandomizeSeed() {
return randomizeSeed;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -158,14 +167,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (trainingPercent != null) {
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
}
if (randomizeSeed != null) {
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
}
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
trainingPercent, randomizeSeed);
}

@Override
Expand All @@ -180,7 +192,8 @@ public boolean equals(Object o) {
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& Objects.equals(trainingPercent, that.trainingPercent);
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed);
}

@Override
Expand All @@ -197,6 +210,7 @@ public static class Builder {
private Double featureBagFraction;
private String predictionFieldName;
private Double trainingPercent;
private Long randomizeSeed;

private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
Expand Down Expand Up @@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) {
return this;
}

public Builder setRandomizeSeed(Long randomizeSeed) {
this.randomizeSeed = randomizeSeed;
return this;
}

public Regression build() {
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
trainingPercent, randomizeSeed);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setRandomizeSeed(42L)
.build())
.setDescription("this is a regression")
.build();
Expand Down Expand Up @@ -1326,6 +1327,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setRandomizeSeed(42L)
.setNumTopClasses(1)
.build())
.setDescription("this is a classification")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2975,7 +2975,8 @@ public void testPutDataFrameAnalytics() throws Exception {
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.setNumTopClasses(1) // <9>
.setRandomizeSeed(1234L) // <9>
.setNumTopClasses(1) // <10>
.build();
// end::put-data-frame-analytics-classification

Expand All @@ -2988,6 +2989,7 @@ public void testPutDataFrameAnalytics() throws Exception {
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.setRandomizeSeed(1234L) // <9>
.build();
// end::put-data-frame-analytics-regression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public static Classification randomClassification() {
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setRandomizeSeed(randomBoolean() ? null : randomLong())
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<9> The number of top classes to be reported in the results. Defaults to 2.
<9> The seed to be used by the random generator that picks which rows are used in training.
<10> The number of top classes to be reported in the results. Defaults to 2.

===== Regression

Expand All @@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression]
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
<7> The name of the prediction field in the results object.
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<9> The seed to be used by the random generator that picks which rows are used in training.

==== Analyzed fields

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]

include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]

include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]


[float]
[[regression-resources-advanced]]
Expand Down Expand Up @@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]

include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]

include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]


[float]
[[classification-resources-advanced]]
Expand Down
4 changes: 3 additions & 1 deletion docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
{
"regression": {
"dependent_variable": "G3",
"training_percent": 70 <1>
"training_percent": 70, <1>
"randomize_seed": 19673948271 <2>
}
}
}
Expand All @@ -406,6 +407,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3

<1> The `training_percent` defines the percentage of the data set that will be used
for training the model.
<2> The `randomize_seed` is the seed used to randomly pick which data is used for training.


[[ml-put-dfanalytics-example-c]]
Expand Down
9 changes: 9 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,15 @@ those that contain arrays) won’t be included in the calculation for used
percentage. Defaults to `100`.
end::training_percent[]

tag::randomize_seed[]
`randomize_seed`::
(Optional, long) Defines the seed to the random generator that is used to pick
which documents will be used for training. By default it is randomly generated.
Set it to a specific value to ensure the same documents are used for training
assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same.
end::randomize_seed[]


tag::use-null[]
Defines whether a new series is used as the null series when there is no value
for the by or partition fields. The default value is `false`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(DEST.getPreferredName(), dest);

builder.startObject(ANALYSIS.getPreferredName());
builder.field(analysis.getWriteableName(), analysis);
builder.field(analysis.getWriteableName(), analysis,
new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString())));
builder.endObject();

if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static void declareFields(AbstractObjectParser<?, Void> parser) {
private final Integer maximumNumberTrees;
private final Double featureBagFraction;

BoostedTreeParams(@Nullable Double lambda,
public BoostedTreeParams(@Nullable Double lambda,
@Nullable Double gamma,
@Nullable Double eta,
@Nullable Integer maximumNumberTrees,
Expand All @@ -76,7 +76,7 @@ static void declareFields(AbstractObjectParser<?, Void> parser) {
this.featureBagFraction = featureBagFraction;
}

BoostedTreeParams() {
public BoostedTreeParams() {
this(null, null, null, null, null);
}

Expand Down
Loading

0 comments on commit 269425b

Please sign in to comment.