From 1a27813920ca8a654d209f811c3d3ce4629011d4 Mon Sep 17 00:00:00 2001 From: Kiran Prakash Date: Wed, 7 Aug 2024 10:30:02 -0700 Subject: [PATCH] cancellation related Signed-off-by: Kiran Prakash --- .../AbstractTaskSelectionStrategy.java | 81 +++++ .../cancellation/DefaultTaskCancellation.java | 218 +++++++++++ ...gestRunningTaskFirstSelectionStrategy.java | 29 ++ ...testRunningTaskFirstSelectionStrategy.java | 29 ++ .../cancellation/TaskSelectionStrategy.java | 32 ++ .../wlm/cancellation/package-info.java | 12 + .../DefaultTaskCancellationTests.java | 340 ++++++++++++++++++ ...skFirstStrategySelectionStrategyTests.java | 34 ++ ...skFirstStrategySelectionStrategyTests.java | 34 ++ .../TaskSelectionStrategyTests.java | 121 +++++++ 10 files changed, 930 insertions(+) create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java create mode 100644 server/src/main/java/org/opensearch/wlm/cancellation/package-info.java create mode 100644 server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java new file mode 100644 index 0000000000000..4f592392a3d63 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/AbstractTaskSelectionStrategy.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Represents an abstract task selection strategy. + * This class implements the TaskSelectionStrategy interface and provides a method to select tasks for cancellation based on a sorting condition. + * The specific sorting condition depends on the implementation. + */ +public abstract class AbstractTaskSelectionStrategy implements TaskSelectionStrategy { + + /** + * Returns a comparator that defines the sorting condition for tasks. + * The specific sorting condition depends on the implementation. + * + * @return The comparator + */ + public abstract Comparator sortingCondition(); + + /** + * Selects tasks for cancellation based on the provided limit and resource type. + * The tasks are sorted based on the sorting condition and then selected until the accumulated resource usage reaches the limit. + * + * @param tasks The list of tasks from which to select + * @param limit The limit on the accumulated resource usage + * @param resourceType The type of resource to consider + * @return The list of selected tasks + * @throws IllegalArgumentException If the limit is less than zero + */ + @Override + public List selectTasksForCancellation(List tasks, long limit, ResourceType resourceType) { + if (limit < 0) { + throw new IllegalArgumentException("reduceBy has to be greater than zero"); + } + if (limit == 0) { + return Collections.emptyList(); + } + + List sortedTasks = tasks.stream().sorted(sortingCondition()).collect(Collectors.toList()); + + List selectedTasks = new ArrayList<>(); + long accumulated = 0; + + for (Task task : sortedTasks) { + if (task instanceof CancellableTask) { + selectedTasks.add(createTaskCancellation((CancellableTask) task)); + accumulated += resourceType.getResourceUsage(task); + if (accumulated >= limit) { + break; + } + } + } + return selectedTasks; + } + + private TaskCancellation createTaskCancellation(CancellableTask task) { + // TODO add correct reason and callbacks + return new TaskCancellation(task, List.of(new TaskCancellation.Reason("limits exceeded", 5)), List.of(this::callbackOnCancel)); + } + + private void callbackOnCancel() { + // todo Implement callback logic here mostly used for Stats + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java new file mode 100644 index 0000000000000..d932d21e4affe --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java @@ -0,0 +1,218 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.cluster.metadata.QueryGroup; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.monitor.process.ProcessProbe; +import org.opensearch.search.ResourceType; +import org.opensearch.search.backpressure.settings.NodeDuressSettings; +import org.opensearch.search.backpressure.trackers.NodeDuressTrackers; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; + +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.wlm.tracker.QueryGroupResourceUsageTrackerService.TRACKED_RESOURCES; + +/** + * Manages the cancellation of tasks enforced by QueryGroup thresholds on resource usage criteria. + * This class utilizes a strategy pattern through {@link TaskSelectionStrategy} to identify tasks that exceed + * predefined resource usage limits and are therefore eligible for cancellation. + * + *

The cancellation process is initiated by evaluating the resource usage of each QueryGroup against its + * resource limits. Tasks that contribute to exceeding these limits are selected for cancellation based on the + * implemented task selection strategy.

+ * + *

Instances of this class are configured with a map linking QueryGroup IDs to their corresponding resource usage + * views, a set of active QueryGroups, and a task selection strategy. These components collectively facilitate the + * identification and cancellation of tasks that threaten to breach QueryGroup resource limits.

+ * + * @see TaskSelectionStrategy + * @see QueryGroup + * @see ResourceType + */ +public class DefaultTaskCancellation { + private static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes(); + + protected final TaskSelectionStrategy taskSelectionStrategy; + // a map of QueryGroupId to its corresponding QueryGroupLevelResourceUsageView object + protected final Map queryGroupLevelResourceUsageViews; + protected final Set activeQueryGroups; + protected NodeDuressTrackers nodeDuressTrackers; + + public DefaultTaskCancellation( + TaskSelectionStrategy taskSelectionStrategy, + Map queryGroupLevelResourceUsageViews, + Set activeQueryGroups, + Settings settings, + ClusterSettings clusterSettings + ) { + this.taskSelectionStrategy = taskSelectionStrategy; + this.queryGroupLevelResourceUsageViews = queryGroupLevelResourceUsageViews; + this.activeQueryGroups = activeQueryGroups; + this.nodeDuressTrackers = setupNodeDuressTracker(settings, clusterSettings); + } + + /** + * Cancel tasks based on the implemented strategy. + */ + public final void cancelTasks() { + cancelTasksForMode(QueryGroup.ResiliencyMode.ENFORCED); + + if (nodeDuressTrackers.isNodeInDuress()) { + cancelTasksForMode(QueryGroup.ResiliencyMode.SOFT); + } + } + + private void cancelTasksForMode(QueryGroup.ResiliencyMode resiliencyMode) { + List cancellableTasks = getAllCancellableTasksFrom(resiliencyMode); + for (TaskCancellation taskCancellation : cancellableTasks) { + taskCancellation.cancel(); + } + } + + /** + * Get all cancellable tasks from the QueryGroups. + * + * @return List of tasks that can be cancelled + */ + protected List getAllCancellableTasksFrom(QueryGroup.ResiliencyMode resiliencyMode) { + return getQueryGroupsToCancelFrom(resiliencyMode).stream() + .flatMap(queryGroup -> getCancellableTasksFrom(queryGroup).stream()) + .collect(Collectors.toList()); + } + + /** + * returns the list of QueryGroups breaching their resource limits. + * + * @return List of QueryGroups + */ + private List getQueryGroupsToCancelFrom(QueryGroup.ResiliencyMode resiliencyMode) { + final List queryGroupsToCancelFrom = new ArrayList<>(); + + for (QueryGroup queryGroup : this.activeQueryGroups) { + if (queryGroup.getResiliencyMode() != resiliencyMode) { + continue; + } + Map queryGroupResourceUsage = queryGroupLevelResourceUsageViews.get(queryGroup.get_id()) + .getResourceUsageData(); + + for (ResourceType resourceType : TRACKED_RESOURCES) { + if (queryGroup.getResourceLimits().containsKey(resourceType) && queryGroupResourceUsage.containsKey(resourceType)) { + Double resourceLimit = (Double) queryGroup.getResourceLimits().get(resourceType); + Long resourceUsage = queryGroupResourceUsage.get(resourceType); + + if (isBreachingThreshold(resourceType, resourceLimit, resourceUsage)) { + queryGroupsToCancelFrom.add(queryGroup); + break; + } + } + } + } + + return queryGroupsToCancelFrom; + } + + /** + * Get cancellable tasks from a specific queryGroup. + * + * @param queryGroup The QueryGroup from which to get cancellable tasks + * @return List of tasks that can be cancelled + */ + protected List getCancellableTasksFrom(QueryGroup queryGroup) { + return TRACKED_RESOURCES.stream() + .filter(resourceType -> shouldCancelTasks(queryGroup, resourceType)) + .flatMap(resourceType -> getTaskCancellations(queryGroup, resourceType).stream()) + .collect(Collectors.toList()); + } + + private boolean shouldCancelTasks(QueryGroup queryGroup, ResourceType resourceType) { + long reduceBy = getReduceBy(queryGroup, resourceType); + return reduceBy > 0; + } + + private List getTaskCancellations(QueryGroup queryGroup, ResourceType resourceType) { + return taskSelectionStrategy.selectTasksForCancellation( + // get the active tasks in the query group + queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks(), + getReduceBy(queryGroup, resourceType), + resourceType + ); + } + + private long getReduceBy(QueryGroup queryGroup, ResourceType resourceType) { + if (queryGroup.getResourceLimits().get(resourceType) == null) { + return 0; + } + Double threshold = (Double) queryGroup.getResourceLimits().get(resourceType); + return getResourceUsage(queryGroup, resourceType) - convertThresholdIntoLong(resourceType, threshold); + } + + private Long convertThresholdIntoLong(ResourceType resourceType, Double resourceThresholdInPercentage) { + Long threshold = null; + if (resourceType == ResourceType.MEMORY) { + // Check if resource usage is breaching the threshold + threshold = (long) (resourceThresholdInPercentage * HEAP_SIZE_BYTES); + } else if (resourceType == ResourceType.CPU) { + // Get the total CPU time of the process in milliseconds + long cpuTotalTimeInMillis = ProcessProbe.getInstance().getProcessCpuTotalTime(); + // Check if resource usage is breaching the threshold + threshold = (long) (resourceThresholdInPercentage * cpuTotalTimeInMillis); + } + return threshold; + } + + private Long getResourceUsage(QueryGroup queryGroup, ResourceType resourceType) { + if (!queryGroupLevelResourceUsageViews.containsKey(queryGroup.get_id())) { + return 0L; + } + return queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getResourceUsageData().get(resourceType); + } + + private boolean isBreachingThreshold(ResourceType resourceType, Double resourceThresholdInPercentage, long resourceUsage) { + if (resourceType == ResourceType.MEMORY) { + // Check if resource usage is breaching the threshold + return resourceUsage > convertThresholdIntoLong(resourceType, resourceThresholdInPercentage); + } + // Resource types should be CPU, resourceUsage is in nanoseconds, convert to milliseconds + long resourceUsageInMillis = resourceUsage / 1_000_000; + // Check if resource usage is breaching the threshold + return resourceUsageInMillis > convertThresholdIntoLong(resourceType, resourceThresholdInPercentage); + } + + private NodeDuressTrackers setupNodeDuressTracker(Settings settings, ClusterSettings clusterSettings) { + NodeDuressSettings nodeDuressSettings = new NodeDuressSettings(settings, clusterSettings); + return new NodeDuressTrackers(new EnumMap<>(ResourceType.class) { + { + put( + ResourceType.CPU, + new NodeDuressTrackers.NodeDuressTracker( + () -> ProcessProbe.getInstance().getProcessCpuPercent() / 100.0 >= nodeDuressSettings.getCpuThreshold(), + nodeDuressSettings::getNumSuccessiveBreaches + ) + ); + put( + ResourceType.MEMORY, + new NodeDuressTrackers.NodeDuressTracker( + () -> JvmStats.jvmStats().getMem().getHeapUsedPercent() / 100.0 >= nodeDuressSettings.getHeapThreshold(), + nodeDuressSettings::getNumSuccessiveBreaches + ) + ); + } + }); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java new file mode 100644 index 0000000000000..d36d55b25bb4a --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstSelectionStrategy.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; + +import java.util.Comparator; + +/** + * Represents a task selection strategy that prioritizes the longest running tasks first. + */ +public class LongestRunningTaskFirstSelectionStrategy extends AbstractTaskSelectionStrategy { + + /** + * Returns a comparator that sorts tasks based on their start time in descending order. + * + * @return The comparator + */ + @Override + public Comparator sortingCondition() { + return Comparator.comparingLong(Task::getStartTime); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java new file mode 100644 index 0000000000000..1e8e75b291d05 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstSelectionStrategy.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; + +import java.util.Comparator; + +/** + * Represents a task selection strategy that prioritizes the shortest running tasks first. + */ +public class ShortestRunningTaskFirstSelectionStrategy extends AbstractTaskSelectionStrategy { + + /** + * Returns a comparator that sorts tasks based on their start time in ascending order. + * + * @return The comparator + */ + @Override + public Comparator sortingCondition() { + return Comparator.comparingLong(Task::getStartTime).reversed(); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java new file mode 100644 index 0000000000000..72161671186f2 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/TaskSelectionStrategy.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; + +import java.util.List; + +/** + * Interface for strategies to select tasks for cancellation. + * Implementations of this interface define how tasks are selected for cancellation based on resource usage. + */ +public interface TaskSelectionStrategy { + /** + * Determines which tasks should be cancelled based on the provided criteria. + * + * @param tasks List of tasks available for cancellation. + * @param limit The amount of tasks to select whose resources reach this limit + * @param resourceType The type of resource that needs to be reduced, guiding the selection process. + * + * @return List of tasks that should be cancelled. + */ + List selectTasksForCancellation(List tasks, long limit, ResourceType resourceType); +} diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java b/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java new file mode 100644 index 0000000000000..9618d22c9d5e2 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/cancellation/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * QueryGroup resource cancellation artifacts + */ +package org.opensearch.wlm.cancellation; diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java new file mode 100644 index 0000000000000..0c8f186ed425b --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java @@ -0,0 +1,340 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchTask; +import org.opensearch.cluster.metadata.QueryGroup; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.search.ResourceType; +import org.opensearch.search.backpressure.trackers.NodeDuressTrackers; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.junit.Before; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class DefaultTaskCancellationTests extends OpenSearchTestCase { + private static final String queryGroupId1 = "queryGroup1"; + private static final String queryGroupId2 = "queryGroup2"; + + private static class TestTaskCancellationImpl extends DefaultTaskCancellation { + + public TestTaskCancellationImpl( + TaskSelectionStrategy taskSelectionStrategy, + Map queryGroupLevelViews, + Set activeQueryGroups + ) { + super( + taskSelectionStrategy, + queryGroupLevelViews, + activeQueryGroups, + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + } + } + + private Map queryGroupLevelViews; + private Set activeQueryGroups; + private DefaultTaskCancellation taskCancellation; + + @Before + public void setup() { + queryGroupLevelViews = new HashMap<>(); + activeQueryGroups = new HashSet<>(); + taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + } + + public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThreshold() { + ResourceType resourceType = ResourceType.CPU; + long usage = 100_000_000L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup1); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThresholdForMemory() { + ResourceType resourceType = ResourceType.MEMORY; + long usage = 900_000_000_000L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_returnsNoTasksWhenNotBreachingThreshold() { + ResourceType resourceType = ResourceType.CPU; + long usage = 500L; + Double threshold = 0.9; + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup1); + assertTrue(cancellableTasksFrom.isEmpty()); + } + + public void testGetCancellableTasksFrom_filtersQueryGroupCorrectly() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.SOFT); + assertEquals(0, cancellableTasksFrom.size()); + } + + public void testCancelTasks_cancelsGivenTasks() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + taskCancellation.cancelTasks(); + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + } + + public void testCancelTasks_cancelsGivenTasks_WhenNodeInDuress() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroup queryGroup2 = new QueryGroup( + "testQueryGroup", + queryGroupId2, + QueryGroup.ResiliencyMode.SOFT, + Map.of(resourceType, threshold), + 1L + ); + + queryGroupLevelViews.put(queryGroupId1, createResourceUsageViewMock(resourceType, usage)); + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(5678), getRandomSearchTask(8765))); + queryGroupLevelViews.put(queryGroupId2, mockView); + Collections.addAll(activeQueryGroups, queryGroup1, queryGroup2); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new TaskSelectionStrategyTests.TestTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups + ); + + NodeDuressTrackers mock = mock(NodeDuressTrackers.class); + when(mock.isNodeInDuress()).thenReturn(true); + taskCancellation.nodeDuressTrackers = mock; + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + List cancellableTasksFrom1 = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.SOFT); + assertEquals(2, cancellableTasksFrom1.size()); + assertEquals(5678, cancellableTasksFrom1.get(0).getTask().getId()); + assertEquals(8765, cancellableTasksFrom1.get(1).getTask().getId()); + + taskCancellation.cancelTasks(); + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + assertTrue(cancellableTasksFrom1.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom1.get(1).getTask().isCancelled()); + } + + public void testGetAllCancellableTasks_ReturnsNoTasksFromWhenNotBreachingThresholds() { + ResourceType resourceType = ResourceType.CPU; + long usage = 1L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List allCancellableTasks = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertTrue(allCancellableTasks.isEmpty()); + } + + public void testGetAllCancellableTasks_ReturnsTasksFromWhenBreachingThresholds() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + + List allCancellableTasks = taskCancellation.getAllCancellableTasksFrom(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, allCancellableTasks.size()); + assertEquals(1234, allCancellableTasks.get(0).getTask().getId()); + assertEquals(4321, allCancellableTasks.get(1).getTask().getId()); + } + + public void testGetCancellableTasksFrom_doesNotReturnTasksWhenQueryGroupIdNotFound() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000_000L; + Double threshold = 0.01; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + QueryGroup queryGroup2 = new QueryGroup( + "testQueryGroup", + queryGroupId2, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + + queryGroupLevelViews.put(queryGroupId1, mockView); + activeQueryGroups.add(queryGroup1); + activeQueryGroups.add(queryGroup2); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup2); + assertEquals(0, cancellableTasksFrom.size()); + } + + private QueryGroupLevelResourceUsageView createResourceUsageViewMock(ResourceType resourceType, Long usage) { + QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class); + when(mockView.getResourceUsageData()).thenReturn(Collections.singletonMap(resourceType, usage)); + when(mockView.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(1234), getRandomSearchTask(4321))); + return mockView; + } + + private Task getRandomSearchTask(long id) { + return new SearchTask( + id, + "transport", + SearchAction.NAME, + () -> "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java new file mode 100644 index 0000000000000..ad76a5021b175 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/LongestRunningTaskFirstStrategySelectionStrategyTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class LongestRunningTaskFirstStrategySelectionStrategyTests extends OpenSearchTestCase { + public void testSortingCondition() { + Task task1 = mock(Task.class); + Task task2 = mock(Task.class); + Task task3 = mock(Task.class); + when(task1.getStartTime()).thenReturn(100L); + when(task2.getStartTime()).thenReturn(200L); + when(task3.getStartTime()).thenReturn(300L); + + List tasks = Arrays.asList(task2, task1, task3); + tasks.sort(new LongestRunningTaskFirstSelectionStrategy().sortingCondition()); + + assertEquals(Arrays.asList(task1, task2, task3), tasks); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java new file mode 100644 index 0000000000000..3c07df09f6f5e --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/ShortestRunningTaskFirstStrategySelectionStrategyTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShortestRunningTaskFirstStrategySelectionStrategyTests extends OpenSearchTestCase { + public void testSortingCondition() { + Task task1 = mock(Task.class); + Task task2 = mock(Task.class); + Task task3 = mock(Task.class); + when(task1.getStartTime()).thenReturn(100L); + when(task2.getStartTime()).thenReturn(200L); + when(task3.getStartTime()).thenReturn(300L); + + List tasks = Arrays.asList(task1, task3, task2); + tasks.sort(new ShortestRunningTaskFirstSelectionStrategy().sortingCondition()); + + assertEquals(Arrays.asList(task3, task2, task1), tasks); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java new file mode 100644 index 0000000000000..43ccbd0920068 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/cancellation/TaskSelectionStrategyTests.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.cancellation; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchTask; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +public class TaskSelectionStrategyTests extends OpenSearchTestCase { + + public static class TestTaskSelectionStrategy extends AbstractTaskSelectionStrategy { + @Override + public Comparator sortingCondition() { + return Comparator.comparingLong(Task::getId); + } + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsGreaterThanZero() { + TaskSelectionStrategy testTaskSelectionStrategy = new TestTaskSelectionStrategy(); + long threshold = 100L; + long reduceBy = 50L; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(threshold); + + List selectedTasks = testTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + assertFalse(selectedTasks.isEmpty()); + assertTrue(tasksUsageMeetsThreshold(selectedTasks, reduceBy)); + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsLesserThanZero() { + TaskSelectionStrategy testTaskSelectionStrategy = new TestTaskSelectionStrategy(); + long threshold = 100L; + long reduceBy = -50L; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(threshold); + + try { + testTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + } catch (Exception e) { + assertTrue(e instanceof IllegalArgumentException); + assertEquals("reduceBy has to be greater than zero", e.getMessage()); + } + } + + public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsEqualToZero() { + TaskSelectionStrategy testTaskSelectionStrategy = new TestTaskSelectionStrategy(); + long threshold = 100L; + long reduceBy = 0; + ResourceType resourceType = ResourceType.MEMORY; + List tasks = getListOfTasks(threshold); + + List selectedTasks = testTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); + assertTrue(selectedTasks.isEmpty()); + } + + private boolean tasksUsageMeetsThreshold(List selectedTasks, long threshold) { + long memory = 0; + for (TaskCancellation task : selectedTasks) { + memory += task.getTask().getTotalResourceUtilization(ResourceStats.MEMORY); + if (memory > threshold) { + return true; + } + } + return false; + } + + private List getListOfTasks(long totalMemory) { + List tasks = new ArrayList<>(); + + while (totalMemory > 0) { + long id = randomLong(); + final Task task = getRandomSearchTask(id); + long initial_memory = randomLongBetween(1, 100); + + ResourceUsageMetric[] initialTaskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, initial_memory) }; + task.startThreadResourceTracking(id, ResourceStatsType.WORKER_STATS, initialTaskResourceMetrics); + + long memory = initial_memory + randomLongBetween(1, 10000); + + totalMemory -= memory - initial_memory; + + ResourceUsageMetric[] taskResourceMetrics = new ResourceUsageMetric[] { + new ResourceUsageMetric(ResourceStats.MEMORY, memory), }; + task.updateThreadResourceStats(id, ResourceStatsType.WORKER_STATS, taskResourceMetrics); + task.stopThreadResourceTracking(id, ResourceStatsType.WORKER_STATS); + tasks.add(task); + } + + return tasks; + } + + private Task getRandomSearchTask(long id) { + return new SearchTask( + id, + "transport", + SearchAction.NAME, + () -> "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + } +}