[ML] Reimplement established model memory #35263

Merged
Changes from 4 commits
MlMetadata.java
@@ -40,6 +40,7 @@
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.io.IOException;
import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
@@ -57,24 +58,28 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom {
public static final String TYPE = "ml";
private static final ParseField JOBS_FIELD = new ParseField("jobs");
private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds");
private static final ParseField LAST_MEMORY_REFRESH_TIME_FIELD = new ParseField("last_memory_refresh_time");

public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap());
public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), null);
// This parser follows the pattern that metadata is parsed leniently (to allow for enhancements)
public static final ObjectParser<Builder, Void> LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new);

static {
LENIENT_PARSER.declareObjectArray(Builder::putJobs, (p, c) -> Job.LENIENT_PARSER.apply(p, c).build(), JOBS_FIELD);
LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds,
(p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD);
LENIENT_PARSER.declareLong(Builder::setLastMemoryRefreshTimeMs, LAST_MEMORY_REFRESH_TIME_FIELD);
}
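
The leniency called out in the comment above is a rolling-upgrade safeguard: an older node may re-parse cluster-state metadata written by a newer node containing fields it has never heard of, such as the new last_memory_refresh_time. A minimal, self-contained sketch of the idea in plain Java (illustrative only; this is not the ObjectParser API):

import java.util.HashMap;
import java.util.Map;

// Illustrative stand-in for lenient parsing: known fields are applied,
// unknown fields (e.g. written by a newer version) are skipped, not fatal.
public class LenientParseSketch {

    static Long parseLastMemoryRefreshTime(Map<String, Object> source) {
        Long lastMemoryRefreshTimeMs = null;
        for (Map.Entry<String, Object> field : source.entrySet()) {
            if ("last_memory_refresh_time".equals(field.getKey())) {
                lastMemoryRefreshTimeMs = ((Number) field.getValue()).longValue();
            }
            // lenient: any other key falls through silently instead of throwing
        }
        return lastMemoryRefreshTimeMs;
    }

    public static void main(String[] args) {
        Map<String, Object> source = new HashMap<>();
        source.put("last_memory_refresh_time", 1541609694000L);
        source.put("field_from_a_newer_version", "ignored");
        System.out.println(parseLastMemoryRefreshTime(source)); // prints 1541609694000
    }
}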

private final SortedMap<String, Job> jobs;
private final SortedMap<String, DatafeedConfig> datafeeds;
private final Instant lastMemoryRefreshTime;
private final GroupOrJobLookup groupOrJobLookup;

private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds) {
private MlMetadata(SortedMap<String, Job> jobs, SortedMap<String, DatafeedConfig> datafeeds, Instant lastMemoryRefreshTime) {
this.jobs = Collections.unmodifiableSortedMap(jobs);
this.datafeeds = Collections.unmodifiableSortedMap(datafeeds);
this.lastMemoryRefreshTime = lastMemoryRefreshTime;
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

@@ -112,6 +117,10 @@ public Set<String> expandDatafeedIds(String expression, boolean allowNoDatafeeds
.expand(expression, allowNoDatafeeds);
}

public Instant getLastMemoryRefreshTime() {
return lastMemoryRefreshTime;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.V_5_4_0;
@@ -145,14 +154,27 @@ public MlMetadata(StreamInput in) throws IOException {
datafeeds.put(in.readString(), new DatafeedConfig(in));
}
this.datafeeds = datafeeds;

if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshTime = in.readBoolean() ? Instant.ofEpochSecond(in.readVLong(), in.readVInt()) : null;
} else {
lastMemoryRefreshTime = null;
}
this.groupOrJobLookup = new GroupOrJobLookup(jobs.values());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
writeMap(jobs, out);
writeMap(datafeeds, out);
if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
if (lastMemoryRefreshTime == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeVLong(lastMemoryRefreshTime.getEpochSecond());
out.writeVInt(lastMemoryRefreshTime.getNano());
}
}
}
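
A note on the transport encoding above: a nullable value is preceded by a boolean presence flag, and the Instant is split into epoch seconds plus a nanosecond adjustment, so the wire format round-trips full Instant precision (the X-Content form, by contrast, keeps only milliseconds). A self-contained sketch of the same scheme, with plain java.io streams standing in for Elasticsearch's StreamOutput/StreamInput, which additionally use variable-length integer encodings:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.time.Instant;

public class OptionalInstantWireFormat {

    // Presence flag first, then epoch seconds and nanos - mirrors writeTo() above.
    static void writeOptionalInstant(DataOutputStream out, Instant value) throws IOException {
        if (value == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeLong(value.getEpochSecond());
            out.writeInt(value.getNano());
        }
    }

    static Instant readOptionalInstant(DataInputStream in) throws IOException {
        return in.readBoolean() ? Instant.ofEpochSecond(in.readLong(), in.readInt()) : null;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        writeOptionalInstant(new DataOutputStream(bytes), Instant.now());
        Instant roundTripped = readOptionalInstant(
                new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(roundTripped); // identical to the written Instant, nanos included
    }
}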

private static <T extends Writeable> void writeMap(Map<String, T> map, StreamOutput out) throws IOException {
@@ -169,6 +191,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
new DelegatingMapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"), params);
mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams);
mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams);
if (lastMemoryRefreshTime != null) {
// We lose precision lower than milliseconds here - OK as millisecond precision is adequate for this use case
builder.timeField(LAST_MEMORY_REFRESH_TIME_FIELD.getPreferredName(),
LAST_MEMORY_REFRESH_TIME_FIELD.getPreferredName() + "_string", lastMemoryRefreshTime.toEpochMilli());
}
return builder;
}
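
For reference, the rendered X-Content looks roughly like the sketch below (values illustrative). builder.timeField() with two field names always writes the raw millisecond value and, when human-readable output is requested, also the *_string companion:

{
  "jobs": [ ... ],
  "datafeeds": [ ... ],
  "last_memory_refresh_time": 1541609694000,
  "last_memory_refresh_time_string": "2018-11-07T16:54:54.000Z"
}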

@@ -185,30 +212,47 @@ public static class MlMetadataDiff implements NamedDiff<MetaData.Custom> {

final Diff<Map<String, Job>> jobs;
final Diff<Map<String, DatafeedConfig>> datafeeds;
final Instant lastMemoryRefreshTime;

MlMetadataDiff(MlMetadata before, MlMetadata after) {
this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer());
this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer());
this.lastMemoryRefreshTime = after.lastMemoryRefreshTime;
}

public MlMetadataDiff(StreamInput in) throws IOException {
this.jobs = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), Job::new,
MlMetadataDiff::readJobDiffFrom);
this.datafeeds = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), DatafeedConfig::new,
MlMetadataDiff::readSchedulerDiffFrom);
if (in.getVersion().onOrAfter(Version.V_6_6_0)) {
lastMemoryRefreshTime = in.readBoolean() ? Instant.ofEpochSecond(in.readVLong(), in.readVInt()) : null;
} else {
lastMemoryRefreshTime = null;
}
}

@Override
public MetaData.Custom apply(MetaData.Custom part) {
TreeMap<String, Job> newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs));
TreeMap<String, DatafeedConfig> newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds));
return new MlMetadata(newJobs, newDatafeeds);
// Use the refresh time carried by the diff (the "after" value); reading it
// from "part" would discard the value that readFrom()/writeTo() serialize.
return new MlMetadata(newJobs, newDatafeeds, lastMemoryRefreshTime);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
jobs.writeTo(out);
datafeeds.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_6_6_0)) {
if (lastMemoryRefreshTime == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeVLong(lastMemoryRefreshTime.getEpochSecond());
out.writeVInt(lastMemoryRefreshTime.getNano());
}
}
}
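
Note the asymmetry in MlMetadataDiff: jobs and datafeeds are delta-encoded per key via DiffableUtils, whereas lastMemoryRefreshTime is shipped whole in every diff, a single Instant being too small to be worth delta-encoding. The contract being implemented is, reduced to its core (a simplification of org.elasticsearch.cluster.Diff):

// A node holding the "before" state applies a received diff to
// reconstruct the "after" state without shipping the full state.
interface StateDiff<T> {
    T apply(T before);
}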

@Override
@@ -233,7 +277,8 @@ public boolean equals(Object o) {
return false;
MlMetadata that = (MlMetadata) o;
return Objects.equals(jobs, that.jobs) &&
Objects.equals(datafeeds, that.datafeeds);
Objects.equals(datafeeds, that.datafeeds) &&
Objects.equals(lastMemoryRefreshTime, that.lastMemoryRefreshTime);
}

@Override
@@ -243,13 +288,14 @@ public final String toString() {

@Override
public int hashCode() {
return Objects.hash(jobs, datafeeds);
return Objects.hash(jobs, datafeeds, lastMemoryRefreshTime);
}

public static class Builder {

private TreeMap<String, Job> jobs;
private TreeMap<String, DatafeedConfig> datafeeds;
private Instant lastMemoryRefreshTime;

public Builder() {
jobs = new TreeMap<>();
@@ -263,6 +309,7 @@ public Builder(@Nullable MlMetadata previous) {
} else {
jobs = new TreeMap<>(previous.jobs);
datafeeds = new TreeMap<>(previous.datafeeds);
lastMemoryRefreshTime = previous.lastMemoryRefreshTime;
}
}

@@ -382,8 +429,18 @@ private Builder putDatafeeds(Collection<DatafeedConfig> datafeeds) {
return this;
}

Builder setLastMemoryRefreshTimeMs(long lastMemoryRefreshTimeMs) {
lastMemoryRefreshTime = Instant.ofEpochMilli(lastMemoryRefreshTimeMs);
return this;
}

public Builder setLastMemoryRefreshTime(Instant lastMemoryRefreshTime) {
this.lastMemoryRefreshTime = lastMemoryRefreshTime;
return this;
}

public MlMetadata build() {
return new MlMetadata(jobs, datafeeds);
return new MlMetadata(jobs, datafeeds, lastMemoryRefreshTime);
}
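
The Builder now exposes two entry points for the refresh time: the package-private millisecond setter targeted by the lenient parser, and the public Instant setter for runtime updates. A hypothetical call site, not part of this diff, might look like:

// Hypothetical usage; clusterState is an assumed in-scope ClusterState.
MlMetadata current = MlMetadata.getMlMetadata(clusterState);
MlMetadata updated = new MlMetadata.Builder(current)
        .setLastMemoryRefreshTime(Instant.now())
        .build();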

public void markJobAsDeleting(String jobId, PersistentTasksCustomMetaData tasks, boolean allowDeleteOpenJob) {
@@ -420,8 +477,6 @@ void checkJobHasNoDatafeed(String jobId) {
}
}



public static MlMetadata getMlMetadata(ClusterState state) {
MlMetadata mlMetadata = (state == null) ? null : state.getMetaData().custom(TYPE);
if (mlMetadata == null) {
Job.java
@@ -142,6 +142,7 @@ private static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFie
private final Date createTime;
private final Date finishedTime;
private final Date lastDataTime;
// TODO: Remove in 7.0
private final Long establishedModelMemory;
private final AnalysisConfig analysisConfig;
private final AnalysisLimits analysisLimits;
@@ -439,6 +440,7 @@ public Collection<String> allInputFields() {
* program code and stack.
* @return an estimate of the memory requirement of this job, in bytes
*/
// TODO: remove this method in 7.0
public long estimateMemoryFootprint() {
if (establishedModelMemory != null && establishedModelMemory > 0) {
return establishedModelMemory + PROCESS_MEMORY_OVERHEAD.getBytes();
@@ -658,6 +660,7 @@ public static class Builder implements Writeable, ToXContentObject {
private Date createTime;
private Date finishedTime;
private Date lastDataTime;
// TODO: remove in 7.0
private Long establishedModelMemory;
private ModelPlotConfig modelPlotConfig;
private Long renormalizationWindowDays;
@@ -1102,10 +1105,6 @@ private void validateGroups() {
public Job build(Date createTime) {
setCreateTime(createTime);
setJobVersion(Version.CURRENT);
// TODO: Maybe we _could_ accept a value for this supplied at create time - it would
// mean cloned jobs that hadn't been edited much would start with an accurate expected size.
// But on the other hand it would mean jobs that were cloned and then completely changed
// would start with a size that was completely wrong.
setEstablishedModelMemory(null);
return build();
}
JobTests.java
@@ -561,7 +561,7 @@ public void testEstimateMemoryFootprint_GivenNoLimitAndNotEstablished() {
builder.setEstablishedModelMemory(0L);
}
assertEquals(ByteSizeUnit.MB.toBytes(AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB)
+ Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint());
}

public void testEarliestValidTimestamp_GivenEmptyDataCounts() {
MachineLearning.java
@@ -181,6 +181,7 @@
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
import org.elasticsearch.xpack.ml.notifications.Auditor;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction;
@@ -278,6 +279,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu

private final SetOnce<AutodetectProcessManager> autodetectProcessManager = new SetOnce<>();
private final SetOnce<DatafeedManager> datafeedManager = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();

public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
@@ -420,6 +422,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
this.datafeedManager.set(datafeedManager);
MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
autodetectProcessManager);
MlMemoryTracker memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider);
this.memoryTracker.set(memoryTracker);

// This object's constructor attaches to the license state, so there's no need to retain another reference to it
new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager);
@@ -438,7 +442,8 @@
jobDataCountsPersister,
datafeedManager,
auditor,
new MlAssignmentNotifier(auditor, clusterService)
new MlAssignmentNotifier(auditor, clusterService),
memoryTracker
);
}

@@ -449,7 +454,8 @@ public List<PersistentTasksExecutor<?>> getPersistentTasksExecutor(ClusterServic
}

return Arrays.asList(
new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get()),
new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(),
memoryTracker.get()),
new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(settings, datafeedManager.get())
);
}
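
memoryTracker joins autodetectProcessManager and datafeedManager as a SetOnce field: the plugin object exists before createComponents() runs, so late-built components are kept in write-once holders that later plugin hooks such as getPersistentTasksExecutor() can read. A minimal sketch of the pattern, using Lucene's SetOnce as Elasticsearch does:

import org.apache.lucene.util.SetOnce;

// Write-once holder: set() succeeds exactly once; get() returns null
// until then, so readers must only run after initialization.
public class SetOnceSketch {
    private final SetOnce<String> component = new SetOnce<>();

    void init() {
        component.set("built in createComponents()");
        // a second set() would throw SetOnce.AlreadySetException
    }

    String use() {
        return component.get();
    }
}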
TransportDeleteJobAction.java
@@ -69,6 +69,7 @@
import org.elasticsearch.xpack.ml.job.persistence.JobDataDeleter;
import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
import org.elasticsearch.xpack.ml.notifications.Auditor;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;

import java.util.ArrayList;
@@ -94,6 +95,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction<DeleteJo
private final JobResultsProvider jobResultsProvider;
private final JobConfigProvider jobConfigProvider;
private final DatafeedConfigProvider datafeedConfigProvider;
private final MlMemoryTracker memoryTracker;

/**
* A map of task listeners by job_id.
@@ -108,7 +110,8 @@ public TransportDeleteJobAction(Settings settings, TransportService transportSer
ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, PersistentTasksService persistentTasksService,
Client client, Auditor auditor, JobResultsProvider jobResultsProvider,
JobConfigProvider jobConfigProvider, DatafeedConfigProvider datafeedConfigProvider) {
JobConfigProvider jobConfigProvider, DatafeedConfigProvider datafeedConfigProvider,
MlMemoryTracker memoryTracker) {
super(settings, DeleteJobAction.NAME, transportService, clusterService, threadPool, actionFilters,
indexNameExpressionResolver, DeleteJobAction.Request::new);
this.client = client;
@@ -117,6 +120,7 @@ public TransportDeleteJobAction(Settings settings, TransportService transportSer
this.jobResultsProvider = jobResultsProvider;
this.jobConfigProvider = jobConfigProvider;
this.datafeedConfigProvider = datafeedConfigProvider;
this.memoryTracker = memoryTracker;
this.listenersByJobId = new HashMap<>();
}

@@ -210,6 +214,7 @@ private void notifyListeners(String jobId, @Nullable AcknowledgedResponse ack, @
private void normalDeleteJob(ParentTaskAssigningClient parentTaskClient, DeleteJobAction.Request request,
ActionListener<AcknowledgedResponse> listener) {
String jobId = request.getJobId();
memoryTracker.removeJob(jobId);

// Step 4. When the job has been removed from the cluster state, return a response
// -------
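
MlMemoryTracker itself is not shown in this diff, so only its contract is visible here: deleting a job must also evict its cached memory requirement, so that a stale value cannot influence later node assignment. A speculative sketch of that slice of the tracker (names and structure assumed, not taken from the PR):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Speculative sketch of the removeJob() contract; not the actual MlMemoryTracker.
public class MemoryTrackerSketch {
    private final Map<String, Long> memoryRequirementByJob = new ConcurrentHashMap<>();

    void setMemoryRequirement(String jobId, long bytes) {
        memoryRequirementByJob.put(jobId, bytes);
    }

    Long getMemoryRequirement(String jobId) {
        // null means unknown; a refresh is needed before relying on it
        return memoryRequirementByJob.get(jobId);
    }

    void removeJob(String jobId) {
        memoryRequirementByJob.remove(jobId);
    }
}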