diff --git a/core/trino-main/src/main/java/io/trino/memory/TotalReservationOnBlockedNodesLowMemoryKiller.java b/core/trino-main/src/main/java/io/trino/memory/TotalReservationOnBlockedNodesLowMemoryKiller.java index a2c9396f72b3..8d1797a84449 100644 --- a/core/trino-main/src/main/java/io/trino/memory/TotalReservationOnBlockedNodesLowMemoryKiller.java +++ b/core/trino-main/src/main/java/io/trino/memory/TotalReservationOnBlockedNodesLowMemoryKiller.java @@ -14,6 +14,9 @@ package io.trino.memory; +import com.google.common.collect.ImmutableSet; +import io.trino.TaskMemoryInfo; +import io.trino.execution.TaskId; import io.trino.spi.QueryId; import io.trino.spi.memory.MemoryPoolInfo; @@ -21,7 +24,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; +import static java.util.Comparator.comparing; import static java.util.Comparator.comparingLong; public class TotalReservationOnBlockedNodesLowMemoryKiller @@ -29,6 +34,39 @@ public class TotalReservationOnBlockedNodesLowMemoryKiller { @Override public Optional chooseQueryToKill(List runningQueries, List nodes) + { + Optional killTarget = chooseTasksToKill(nodes); + if (killTarget.isEmpty()) { + killTarget = chooseWholeQueryToKill(nodes); + } + return killTarget; + } + + private Optional chooseTasksToKill(List nodes) + { + ImmutableSet.Builder tasksToKillBuilder = ImmutableSet.builder(); + for (MemoryInfo node : nodes) { + MemoryPoolInfo memoryPool = node.getPool(); + if (memoryPool == null) { + continue; + } + if (memoryPool.getFreeBytes() + memoryPool.getReservedRevocableBytes() > 0) { + continue; + } + + node.getTasksMemoryInfo().values().stream() + .max(comparing(TaskMemoryInfo::getMemoryReservation)) + .map(TaskMemoryInfo::getTaskId) + .ifPresent(tasksToKillBuilder::add); + } + Set tasksToKill = tasksToKillBuilder.build(); + if (tasksToKill.isEmpty()) { + return Optional.empty(); + } + return Optional.of(KillTarget.selectedTasks(tasksToKill)); + } + + private Optional chooseWholeQueryToKill(List nodes) { Map memoryReservationOnBlockedNodes = new HashMap<>(); for (MemoryInfo node : nodes) { diff --git a/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java b/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java index 85762cf242d4..eee2b32b2f21 100644 --- a/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java +++ b/core/trino-main/src/test/java/io/trino/memory/LowMemoryKillerTestingUtils.java @@ -15,8 +15,12 @@ package io.trino.memory; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; +import io.trino.TaskMemoryInfo; import io.trino.client.NodeVersion; +import io.trino.execution.TaskId; import io.trino.metadata.InternalNode; import io.trino.spi.QueryId; import io.trino.spi.memory.MemoryPoolInfo; @@ -31,6 +35,11 @@ public final class LowMemoryKillerTestingUtils private LowMemoryKillerTestingUtils() {} static List toNodeMemoryInfoList(long memoryPoolMaxBytes, Map> queries) + { + return toNodeMemoryInfoList(memoryPoolMaxBytes, queries, ImmutableMap.of()); + } + + static List toNodeMemoryInfoList(long memoryPoolMaxBytes, Map> queries, Map>> tasks) { Map nodeReservations = new HashMap<>(); @@ -58,7 +67,27 @@ static List toNodeMemoryInfoList(long memoryPoolMaxBytes, Map tasksMemoryInfoForNode(String nodeIdentifier, Map>> tasks) + { + ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); + for (Map.Entry>> queryNodesEntry : tasks.entrySet()) { + QueryId query = QueryId.valueOf(queryNodesEntry.getKey()); + for (Map.Entry> nodeTasksEntry : queryNodesEntry.getValue().entrySet()) { + if (!nodeIdentifier.equals(nodeTasksEntry.getKey())) { + continue; + } + + for (Map.Entry taskReservationEntry : nodeTasksEntry.getValue().entrySet()) { + TaskId taskId = TaskId.valueOf(taskReservationEntry.getKey()); + long taskReservation = taskReservationEntry.getValue(); + result.put(query, new TaskMemoryInfo(taskId, taskReservation)); + } + } } return result.build(); } diff --git a/core/trino-main/src/test/java/io/trino/memory/TestTotalReservationOnBlockedNodesLowMemoryKiller.java b/core/trino-main/src/test/java/io/trino/memory/TestTotalReservationOnBlockedNodesLowMemoryKiller.java index dc5ed53f0c13..c280f26e5a15 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestTotalReservationOnBlockedNodesLowMemoryKiller.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestTotalReservationOnBlockedNodesLowMemoryKiller.java @@ -15,6 +15,8 @@ package io.trino.memory; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.execution.TaskId; import io.trino.spi.QueryId; import org.testng.annotations.Test; @@ -36,6 +38,7 @@ public void testMemoryPoolHasNoReservation() Map> queries = ImmutableMap.>builder() .put("q_1", ImmutableMap.of("n1", 0L, "n2", 0L, "n3", 0L, "n4", 0L, "n5", 0L)) .buildOrThrow(); + assertEquals( lowMemoryKiller.chooseQueryToKill( toQueryMemoryInfoList(queries), @@ -75,4 +78,58 @@ public void testSkewedQuery() toNodeMemoryInfoList(memoryPool, queries)), Optional.of(KillTarget.wholeQuery(new QueryId("q_1")))); } + + @Test + public void testPreferKillingTasks() + { + int memoryPool = 12; + Map> queries = ImmutableMap.>builder() + .put("q_1", ImmutableMap.of("n1", 0L, "n2", 8L, "n3", 0L, "n4", 0L, "n5", 0L)) + .put("q_2", ImmutableMap.of("n1", 3L, "n2", 5L, "n3", 2L, "n4", 4L, "n5", 0L)) + .put("q_3", ImmutableMap.of("n1", 0L, "n2", 0L, "n3", 11L, "n4", 0L, "n5", 0L)) + .buildOrThrow(); + + Map>> tasks = ImmutableMap.>>builder() + .put("q_2", ImmutableMap.of( + "n1", ImmutableMap.of("t1", 1L, "t2", 3L), + "n2", ImmutableMap.of("t3", 3L, "t4", 1L, "t5", 1L), + "n3", ImmutableMap.of("t6", 2L), + "n4", ImmutableMap.of("t7", 2L, "t8", 2L), + "n5", ImmutableMap.of()) + ).buildOrThrow(); + + assertEquals( + lowMemoryKiller.chooseQueryToKill( + toQueryMemoryInfoList(queries), + toNodeMemoryInfoList(memoryPool, queries, tasks)), + Optional.of(KillTarget.selectedTasks(ImmutableSet.of(TaskId.valueOf("t3"), TaskId.valueOf("t6"))))); + } + + @Test + public void testKillsBiggestTasks() + { + int memoryPool = 12; + Map> queries = ImmutableMap.>builder() + .put("q_1", ImmutableMap.of("n1", 0L, "n2", 8L, "n3", 0L, "n4", 0L, "n5", 0L)) + .put("q_2", ImmutableMap.of("n1", 3L, "n2", 5L, "n3", 2L, "n4", 4L, "n5", 0L)) + .put("q_3", ImmutableMap.of("n1", 0L, "n2", 0L, "n3", 11L, "n4", 0L, "n5", 0L)) + .buildOrThrow(); + + Map>> tasks = ImmutableMap.>>builder() + .put("q_1", ImmutableMap.of( + "n2", ImmutableMap.of("t1_1", 8L))) + .put("q_2", ImmutableMap.of( + "n1", ImmutableMap.of("t2_1", 1L, "t2_2", 3L), + "n2", ImmutableMap.of("t2_3", 3L, "t2_4", 1L, "t2_5", 1L), + "n3", ImmutableMap.of("t2_6", 2L), + "n4", ImmutableMap.of("t2_7", 2L, "t2_8", 2L), + "n5", ImmutableMap.of())) + .buildOrThrow(); + + assertEquals( + lowMemoryKiller.chooseQueryToKill( + toQueryMemoryInfoList(queries), + toNodeMemoryInfoList(memoryPool, queries, tasks)), + Optional.of(KillTarget.selectedTasks(ImmutableSet.of(TaskId.valueOf("t1_1"), TaskId.valueOf("t2_6"))))); + } }