Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add custom rule parameters to force time shift #110974

Merged
merged 11 commits into from
Jul 25, 2024
5 changes: 5 additions & 0 deletions docs/changelog/110974.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 110974
summary: Add custom rule parameters to force time shift
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NODES_STATS_ENUM_SET = def(8_709_00_0);
public static final TransportVersion MASTER_NODE_METRICS = def(8_710_00_0);
public static final TransportVersion SEGMENT_LEVEL_FIELDS_STATS = def(8_711_00_0);
public static final TransportVersion ML_ADD_DETECTION_RULE_PARAMS = def(8_712_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core.ml.job.config;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -30,6 +31,7 @@ public class DetectionRule implements ToXContentObject, Writeable {
public static final ParseField ACTIONS_FIELD = new ParseField("actions");
public static final ParseField SCOPE_FIELD = new ParseField("scope");
public static final ParseField CONDITIONS_FIELD = new ParseField("conditions");
public static final ParseField PARAMS_FIELD = new ParseField("params");

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
Expand All @@ -45,31 +47,42 @@ private static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFie
ignoreUnknownFields ? RuleCondition.LENIENT_PARSER : RuleCondition.STRICT_PARSER,
CONDITIONS_FIELD
);
parser.declareObject(Builder::setParams, ignoreUnknownFields ? RuleParams.LENIENT_PARSER : RuleParams.STRICT_PARSER, PARAMS_FIELD);

return parser;
}

private final EnumSet<RuleAction> actions;
private final RuleScope scope;
private final List<RuleCondition> conditions;
private final RuleParams params;

private DetectionRule(EnumSet<RuleAction> actions, RuleScope scope, List<RuleCondition> conditions) {
private DetectionRule(EnumSet<RuleAction> actions, RuleScope scope, List<RuleCondition> conditions, RuleParams params) {
this.actions = Objects.requireNonNull(actions);
this.scope = Objects.requireNonNull(scope);
this.conditions = Collections.unmodifiableList(conditions);
this.params = params;
}

public DetectionRule(StreamInput in) throws IOException {
actions = in.readEnumSet(RuleAction.class);
scope = new RuleScope(in);
conditions = in.readCollectionAsList(RuleCondition::new);
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_ADD_DETECTION_RULE_PARAMS)) {
params = new RuleParams(in);
} else {
params = new RuleParams();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnumSet(actions);
scope.writeTo(out);
out.writeCollection(conditions);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_ADD_DETECTION_RULE_PARAMS)) {
params.writeTo(out);
}
}

@Override
Expand All @@ -82,6 +95,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (conditions.isEmpty() == false) {
builder.field(CONDITIONS_FIELD.getPreferredName(), conditions);
}
if (this.params.isEmpty() == false) {
builder.field(PARAMS_FIELD.getPreferredName(), this.params);
}
builder.endObject();
return builder;
}
Expand All @@ -98,6 +114,10 @@ public List<RuleCondition> getConditions() {
return conditions;
}

public RuleParams getParams() {
return params;
}

public Set<String> extractReferencedFilters() {
return scope.getReferencedFilters();
}
Expand All @@ -113,18 +133,22 @@ public boolean equals(Object obj) {
}

DetectionRule other = (DetectionRule) obj;
return Objects.equals(actions, other.actions) && Objects.equals(scope, other.scope) && Objects.equals(conditions, other.conditions);
return Objects.equals(actions, other.actions)
&& Objects.equals(scope, other.scope)
&& Objects.equals(conditions, other.conditions)
&& Objects.equals(params, other.params);
}

@Override
public int hashCode() {
return Objects.hash(actions, scope, conditions);
return Objects.hash(actions, scope, conditions, params);
}

public static class Builder {
private EnumSet<RuleAction> actions = EnumSet.of(RuleAction.SKIP_RESULT);
private RuleScope scope = new RuleScope();
private List<RuleCondition> conditions = Collections.emptyList();
private RuleParams params = new RuleParams();

public Builder(RuleScope.Builder scope) {
this.scope = scope.build();
Expand Down Expand Up @@ -163,12 +187,27 @@ public Builder setConditions(List<RuleCondition> conditions) {
return this;
}

public Builder setParams(RuleParams params) {
this.params = params;
return this;
}

public DetectionRule build() {
if (scope.isEmpty() && conditions.isEmpty()) {
String msg = Messages.getMessage(Messages.JOB_CONFIG_DETECTION_RULE_REQUIRES_SCOPE_OR_CONDITION);
throw ExceptionsHelper.badRequestException(msg);
}
return new DetectionRule(actions, scope, conditions);
// if actions contain FORCE_TIME_SHIFT, then params must contain RuleParamsForForceTimeShift
if (actions.contains(RuleAction.FORCE_TIME_SHIFT) && params.getForceTimeShift() == null) {
String msg = Messages.getMessage(Messages.JOB_CONFIG_DETECTION_RULE_REQUIRES_FORCE_TIME_SHIFT_PARAMS);
throw ExceptionsHelper.badRequestException(msg);
}
// Return error if params must contain RuleParamsForForceTimeShift, but actions do not contain FORCE_TIME_SHIFT
if (actions.contains(RuleAction.FORCE_TIME_SHIFT) == false && params.getForceTimeShift() != null) {
String msg = Messages.getMessage(Messages.JOB_CONFIG_DETECTION_RULE_PARAMS_FORCE_TIME_SHIFT_NOT_REQUIRED);
throw ExceptionsHelper.badRequestException(msg);
}
return new DetectionRule(actions, scope, conditions, params);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

public enum RuleAction implements Writeable {
SKIP_RESULT,
SKIP_MODEL_UPDATE;
SKIP_MODEL_UPDATE,
FORCE_TIME_SHIFT;

/**
* Case-insensitive from string method.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.job.config;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class RuleParams implements ToXContentObject, Writeable {

public static final ParseField RULE_PARAMS_FIELD = new ParseField("params");
public static final ParseField FORCE_TIME_SHIFT_FIELD = new ParseField("force_time_shift");

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ConstructingObjectParser<RuleParams, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<RuleParams, Void> STRICT_PARSER = createParser(false);

public static ConstructingObjectParser<RuleParams, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<RuleParams, Void> parser = new ConstructingObjectParser<>(
RULE_PARAMS_FIELD.getPreferredName(),
ignoreUnknownFields,
a -> new RuleParams((RuleParamsForForceTimeShift) a[0])
);

parser.declareObject(
ConstructingObjectParser.optionalConstructorArg(),
RuleParamsForForceTimeShift.LENIENT_PARSER,
FORCE_TIME_SHIFT_FIELD
);
return parser;
}

private final RuleParamsForForceTimeShift forceTimeShift;

public RuleParams() {
this.forceTimeShift = null;
}

public RuleParams(RuleParamsForForceTimeShift forceTimeShift) {
this.forceTimeShift = forceTimeShift;
}

public RuleParams(StreamInput in) throws IOException {
// initialize optional forceTimeShift from in
forceTimeShift = in.readOptionalWriteable(RuleParamsForForceTimeShift::new);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
// write optional forceTimeShift to out
out.writeOptionalWriteable(forceTimeShift);
}

boolean isEmpty() {
return forceTimeShift == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (forceTimeShift != null) {
builder.field(FORCE_TIME_SHIFT_FIELD.getPreferredName(), forceTimeShift);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}

if (obj instanceof RuleParams == false) {
return false;
}

RuleParams other = (RuleParams) obj;
return Objects.equals(forceTimeShift, other.forceTimeShift);
}

@Override
public int hashCode() {
return Objects.hash(forceTimeShift);
}

public RuleParamsForForceTimeShift getForceTimeShift() {
return forceTimeShift;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.job.config;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;

public class RuleParamsForForceTimeShift implements ToXContentObject, Writeable {
public static final ParseField TYPE_FIELD = new ParseField("force_time_shift_params");
public static final ParseField TIME_SHIFT_AMOUNT_FIELD = new ParseField("time_shift_amount");

public static final ConstructingObjectParser<RuleParamsForForceTimeShift, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<RuleParamsForForceTimeShift, Void> STRICT_PARSER = createParser(false);

private static ConstructingObjectParser<RuleParamsForForceTimeShift, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<RuleParamsForForceTimeShift, Void> parser = new ConstructingObjectParser<>(
TYPE_FIELD.getPreferredName(),
ignoreUnknownFields,
a -> new RuleParamsForForceTimeShift((Long) a[0])
);
parser.declareLong(ConstructingObjectParser.constructorArg(), TIME_SHIFT_AMOUNT_FIELD);
return parser;
}

private final long timeShiftAmount;

public RuleParamsForForceTimeShift(long timeShiftAmount) {
this.timeShiftAmount = timeShiftAmount;
}

public RuleParamsForForceTimeShift(StreamInput in) throws IOException {
timeShiftAmount = in.readLong();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeLong(timeShiftAmount);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TIME_SHIFT_AMOUNT_FIELD.getPreferredName(), timeShiftAmount);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}

if (obj instanceof RuleParamsForForceTimeShift == false) {
return false;
}

RuleParamsForForceTimeShift other = (RuleParamsForForceTimeShift) obj;
return timeShiftAmount == other.timeShiftAmount;
}

@Override
public int hashCode() {
return Long.hashCode(timeShiftAmount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ public final class Messages {
"Invalid detector rule: function {0} only supports conditions that apply to time";
public static final String JOB_CONFIG_DETECTION_RULE_REQUIRES_SCOPE_OR_CONDITION =
"Invalid detector rule: at least scope or a condition is required";
public static final String JOB_CONFIG_DETECTION_RULE_REQUIRES_FORCE_TIME_SHIFT_PARAMS =
"Invalid detector rule: actions contain force_time_shift, but corresponding parameters are missing";
public static final String JOB_CONFIG_DETECTION_RULE_PARAMS_FORCE_TIME_SHIFT_NOT_REQUIRED =
"Invalid detector rule: actions do not contain force_time_shift, but corresponding parameters are present";
public static final String JOB_CONFIG_DETECTION_RULE_SCOPE_NO_AVAILABLE_FIELDS =
"Invalid detector rule: scope field ''{0}'' is invalid; detector has no available fields for scoping";
public static final String JOB_CONFIG_DETECTION_RULE_SCOPE_HAS_INVALID_FIELD =
Expand Down
Loading