diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 99aefd23ed06..4d16ef25e6fe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -25,6 +25,7 @@ import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.SqlQueryScheduler; @@ -99,6 +100,7 @@ public class SqlQueryExecution private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final NodeAllocatorService nodeAllocatorService; private final List planOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory remoteTaskFactory; @@ -132,6 +134,7 @@ private SqlQueryExecution( SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + NodeAllocatorService nodeAllocatorService, List planOptimizers, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory, @@ -159,6 +162,7 @@ private SqlQueryExecution( this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -497,6 +501,7 @@ private void planDistribution(PlanRoot plan) plan.getRoot(), nodePartitioningManager, nodeScheduler, + nodeAllocatorService, remoteTaskFactory, plan.isSummarizeTaskInfos(), scheduleSplitBatchSize, @@ -698,6 +703,7 @@ public static class SqlQueryExecutionFactory private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final NodeAllocatorService nodeAllocatorService; private final List planOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory remoteTaskFactory; @@ -724,6 +730,7 @@ public static class SqlQueryExecutionFactory SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + NodeAllocatorService nodeAllocatorService, PlanOptimizersFactory planOptimizersFactory, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory, @@ -751,6 +758,7 @@ public static class SqlQueryExecutionFactory this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -790,6 +798,7 @@ public QueryExecution createQueryExecution( splitSourceFactory, nodePartitioningManager, nodeScheduler, + nodeAllocatorService, planOptimizers, planFragmenter, remoteTaskFactory, diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java deleted file mode 100644 index 5507285011a8..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.trino.Session; -import io.trino.connector.CatalogName; -import io.trino.metadata.InternalNode; -import io.trino.spi.TrinoException; - -import javax.annotation.concurrent.GuardedBy; - -import java.util.HashMap; -import java.util.IdentityHashMap; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.util.concurrent.Futures.immediateFailedFuture; -import static com.google.common.util.concurrent.Futures.immediateFuture; -import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; -import static java.util.Comparator.comparing; -import static java.util.Objects.requireNonNull; - -public class FixedCountNodeAllocator - implements NodeAllocator -{ - private final NodeScheduler nodeScheduler; - - private final Session session; - private final int maximumAllocationsPerNode; - - @GuardedBy("this") - private final Map, NodeSelector> nodeSelectorCache = new HashMap<>(); - - @GuardedBy("this") - private final Map allocationCountMap = new HashMap<>(); - - @GuardedBy("this") - private final LinkedList pendingAcquires = new LinkedList<>(); - - public FixedCountNodeAllocator( - NodeScheduler nodeScheduler, - Session session, - int maximumAllocationsPerNode) - { - this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); - this.session = requireNonNull(session, "session is null"); - this.maximumAllocationsPerNode = maximumAllocationsPerNode; - } - - @Override - public synchronized ListenableFuture acquire(NodeRequirements requirements) - { - try { - Optional node = tryAcquireNode(requirements); - if (node.isPresent()) { - return immediateFuture(node.get()); - } - } - catch (RuntimeException e) { - return immediateFailedFuture(e); - } - - SettableFuture future = SettableFuture.create(); - PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); - pendingAcquires.add(pendingAcquire); - - return future; - } - - @Override - public void release(InternalNode node) - { - releaseNodeInternal(node); - processPendingAcquires(); - } - - @Override - public void updateNodes() - { - processPendingAcquires(); - } - - private synchronized Optional tryAcquireNode(NodeRequirements requirements) - { - NodeSelector nodeSelector = nodeSelectorCache.computeIfAbsent(requirements.getCatalogName(), catalogName -> nodeScheduler.createNodeSelector(session, catalogName)); - - List nodes = nodeSelector.allNodes(); - if (nodes.isEmpty()) { - throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); - } - - List nodesMatchingRequirements = nodes.stream() - .filter(node -> requirements.getAddresses().isEmpty() || requirements.getAddresses().contains(node.getHostAndPort())) - .collect(toImmutableList()); - - if (nodesMatchingRequirements.isEmpty()) { - throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); - } - - Optional selectedNode = nodesMatchingRequirements.stream() - .filter(node -> allocationCountMap.getOrDefault(node, 0) < maximumAllocationsPerNode) - .min(comparing(node -> allocationCountMap.getOrDefault(node, 0))); - - if (selectedNode.isEmpty()) { - return Optional.empty(); - } - - allocationCountMap.compute(selectedNode.get(), (key, value) -> value == null ? 1 : value + 1); - return selectedNode; - } - - private synchronized void releaseNodeInternal(InternalNode node) - { - int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 0 : value - 1); - checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); - } - - private void processPendingAcquires() - { - verify(!Thread.holdsLock(this)); - - IdentityHashMap assignedNodes = new IdentityHashMap<>(); - IdentityHashMap failures = new IdentityHashMap<>(); - synchronized (this) { - Iterator iterator = pendingAcquires.iterator(); - while (iterator.hasNext()) { - PendingAcquire pendingAcquire = iterator.next(); - if (pendingAcquire.getFuture().isCancelled()) { - iterator.remove(); - continue; - } - try { - Optional node = tryAcquireNode(pendingAcquire.getNodeRequirements()); - if (node.isPresent()) { - iterator.remove(); - assignedNodes.put(pendingAcquire, node.get()); - } - } - catch (RuntimeException e) { - iterator.remove(); - failures.put(pendingAcquire, e); - } - } - } - - assignedNodes.forEach((pendingAcquire, node) -> { - SettableFuture future = pendingAcquire.getFuture(); - future.set(node); - if (future.isCancelled()) { - releaseNodeInternal(node); - } - }); - - failures.forEach((pendingAcquire, failure) -> { - SettableFuture future = pendingAcquire.getFuture(); - future.setException(failure); - }); - } - - @Override - public synchronized void close() - { - } - - private static class PendingAcquire - { - private final NodeRequirements nodeRequirements; - private final SettableFuture future; - - private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) - { - this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); - this.future = requireNonNull(future, "future is null"); - } - - public NodeRequirements getNodeRequirements() - { - return nodeRequirements; - } - - public SettableFuture getFuture() - { - return future; - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java new file mode 100644 index 000000000000..2faf06de4b40 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java @@ -0,0 +1,287 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.log.Logger; +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.metadata.InternalNode; +import io.trino.spi.TrinoException; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Sets.newConcurrentHashSet; +import static com.google.common.util.concurrent.Futures.immediateFailedFuture; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; + +/** + * A simplistic node allocation service which only limits number of allocations per node within each + * {@link FixedCountNodeAllocator} instance. Each allocator will allow each node to be acquired up to {@link FixedCountNodeAllocatorService#MAXIMUM_ALLOCATIONS_PER_NODE} + * times at the same time. + */ +@ThreadSafe +public class FixedCountNodeAllocatorService + implements NodeAllocatorService +{ + private static final Logger log = Logger.get(FixedCountNodeAllocatorService.class); + + // Single FixedCountNodeAllocator will allow for at most MAXIMUM_ALLOCATIONS_PER_NODE. + // If we reach this state subsequent calls to acquire will return blocked lease. + private static final int MAXIMUM_ALLOCATIONS_PER_NODE = 1; // TODO make configurable? + + private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1, daemonThreadsNamed("fixed-count-node-allocator")); + private final NodeScheduler nodeScheduler; + + private final Set allocators = newConcurrentHashSet(); + private final AtomicBoolean started = new AtomicBoolean(); + + @Inject + public FixedCountNodeAllocatorService(NodeScheduler nodeScheduler) + { + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + } + + @PostConstruct + public void start() + { + if (!started.compareAndSet(false, true)) { + // already started + return; + } + executor.scheduleWithFixedDelay(() -> { + try { + updateNodes(); + } + catch (Throwable e) { + // ignore to avoid getting unscheduled + log.warn(e, "Error updating nodes"); + } + }, 5, 5, TimeUnit.SECONDS); + } + + @PreDestroy + public void stop() + { + executor.shutdownNow(); + } + + @VisibleForTesting + void updateNodes() + { + allocators.forEach(FixedCountNodeAllocator::updateNodes); + } + + @Override + public NodeAllocator getNodeAllocator(Session session) + { + requireNonNull(session, "session is null"); + return getNodeAllocator(session, MAXIMUM_ALLOCATIONS_PER_NODE); + } + + @VisibleForTesting + NodeAllocator getNodeAllocator(Session session, int maximumAllocationsPerNode) + { + FixedCountNodeAllocator allocator = new FixedCountNodeAllocator(session, maximumAllocationsPerNode); + allocators.add(allocator); + return allocator; + } + + private class FixedCountNodeAllocator + implements NodeAllocator + { + private final Session session; + private final int maximumAllocationsPerNode; + + @GuardedBy("this") + private final Map, NodeSelector> nodeSelectorCache = new HashMap<>(); + + @GuardedBy("this") + private final Map allocationCountMap = new HashMap<>(); + + @GuardedBy("this") + private final List pendingAcquires = new LinkedList<>(); + + public FixedCountNodeAllocator( + Session session, + int maximumAllocationsPerNode) + { + this.session = requireNonNull(session, "session is null"); + this.maximumAllocationsPerNode = maximumAllocationsPerNode; + } + + @Override + public synchronized ListenableFuture acquire(NodeRequirements requirements) + { + try { + Optional node = tryAcquireNode(requirements); + if (node.isPresent()) { + return immediateFuture(node.get()); + } + } + catch (RuntimeException e) { + return immediateFailedFuture(e); + } + + SettableFuture future = SettableFuture.create(); + PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); + pendingAcquires.add(pendingAcquire); + + return future; + } + + @Override + public void release(InternalNode node) + { + releaseNodeInternal(node); + processPendingAcquires(); + } + + public void updateNodes() + { + processPendingAcquires(); + } + + private synchronized Optional tryAcquireNode(NodeRequirements requirements) + { + NodeSelector nodeSelector = nodeSelectorCache.computeIfAbsent(requirements.getCatalogName(), catalogName -> nodeScheduler.createNodeSelector(session, catalogName)); + + List nodes = nodeSelector.allNodes(); + if (nodes.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + List nodesMatchingRequirements = nodes.stream() + .filter(node -> requirements.getAddresses().isEmpty() || requirements.getAddresses().contains(node.getHostAndPort())) + .collect(toImmutableList()); + + if (nodesMatchingRequirements.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + Optional selectedNode = nodesMatchingRequirements.stream() + .filter(node -> allocationCountMap.getOrDefault(node, 0) < maximumAllocationsPerNode) + .min(comparing(node -> allocationCountMap.getOrDefault(node, 0))); + + if (selectedNode.isEmpty()) { + return Optional.empty(); + } + + allocationCountMap.compute(selectedNode.get(), (key, value) -> value == null ? 1 : value + 1); + return selectedNode; + } + + private synchronized void releaseNodeInternal(InternalNode node) + { + int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 0 : value - 1); + checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); + } + + private void processPendingAcquires() + { + verify(!Thread.holdsLock(this)); + + IdentityHashMap assignedNodes = new IdentityHashMap<>(); + IdentityHashMap failures = new IdentityHashMap<>(); + synchronized (this) { + Iterator iterator = pendingAcquires.iterator(); + while (iterator.hasNext()) { + PendingAcquire pendingAcquire = iterator.next(); + if (pendingAcquire.getFuture().isCancelled()) { + iterator.remove(); + continue; + } + try { + Optional node = tryAcquireNode(pendingAcquire.getNodeRequirements()); + if (node.isPresent()) { + iterator.remove(); + assignedNodes.put(pendingAcquire, node.get()); + } + } + catch (RuntimeException e) { + iterator.remove(); + failures.put(pendingAcquire, e); + } + } + } + + // set futures outside of critical section + assignedNodes.forEach((pendingAcquire, node) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.set(node); + if (future.isCancelled()) { + releaseNodeInternal(node); + } + }); + + failures.forEach((pendingAcquire, failure) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.setException(failure); + }); + } + + @Override + public synchronized void close() + { + allocators.remove(this); + } + } + + private static class PendingAcquire + { + private final NodeRequirements nodeRequirements; + private final SettableFuture future; + + private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) + { + this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); + this.future = requireNonNull(future, "future is null"); + } + + public NodeRequirements getNodeRequirements() + { + return nodeRequirements; + } + + public SettableFuture getFuture() + { + return future; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java index 778c059982e6..f7aaa038f934 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java @@ -25,8 +25,6 @@ public interface NodeAllocator void release(InternalNode node); - void updateNodes(); - @Override void close(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java new file mode 100644 index 000000000000..faea8c229f25 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import io.trino.Session; + +public interface NodeAllocatorService +{ + NodeAllocator getNodeAllocator(Session session); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index 85647ef02433..4baf816283fe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -97,7 +97,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -175,6 +174,7 @@ public class SqlQueryScheduler private final QueryStateMachine queryStateMachine; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final NodeAllocatorService nodeAllocatorService; private final int splitBatchSize; private final ExecutorService executor; private final ScheduledExecutorService schedulerExecutor; @@ -210,6 +210,7 @@ public SqlQueryScheduler( SubPlan plan, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + NodeAllocatorService nodeAllocatorService, RemoteTaskFactory remoteTaskFactory, boolean summarizeTaskInfo, int splitBatchSize, @@ -231,6 +232,7 @@ public SqlQueryScheduler( this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.splitBatchSize = splitBatchSize; this.executor = requireNonNull(queryExecutor, "queryExecutor is null"); this.schedulerExecutor = requireNonNull(schedulerExecutor, "schedulerExecutor is null"); @@ -342,7 +344,7 @@ private synchronized Optional createDistributedStage maxRetryAttempts, schedulerExecutor, schedulerStats, - nodeScheduler); + nodeAllocatorService); break; case QUERY: case NONE: @@ -1727,7 +1729,6 @@ private static class FaultTolerantDistributedStagesScheduler private final List schedulers; private final SplitSchedulerStats schedulerStats; private final NodeAllocator nodeAllocator; - private final ScheduledFuture nodeUpdateTask; private final AtomicBoolean started = new AtomicBoolean(); @@ -1743,7 +1744,7 @@ public static FaultTolerantDistributedStagesScheduler create( int retryAttempts, ScheduledExecutorService scheduledExecutorService, SplitSchedulerStats schedulerStats, - NodeScheduler nodeScheduler) + NodeAllocatorService nodeAllocatorService) { taskDescriptorStorage.initialize(queryStateMachine.getQueryId()); queryStateMachine.addStateChangeListener(state -> { @@ -1760,9 +1761,7 @@ public static FaultTolerantDistributedStagesScheduler create( ImmutableList.Builder schedulers = ImmutableList.builder(); Map exchanges = new HashMap<>(); - - FixedCountNodeAllocator nodeAllocator = new FixedCountNodeAllocator(nodeScheduler, session, 1); - ScheduledFuture nodeUpdateTask = scheduledExecutorService.scheduleAtFixedRate(nodeAllocator::updateNodes, 5, 5, SECONDS); + NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(session); try { // root to children order @@ -1830,8 +1829,7 @@ public static FaultTolerantDistributedStagesScheduler create( queryStateMachine, schedulers.build(), schedulerStats, - nodeAllocator, - nodeUpdateTask); + nodeAllocator); } catch (Throwable t) { for (FaultTolerantStageScheduler scheduler : schedulers.build()) { @@ -1845,7 +1843,6 @@ public static FaultTolerantDistributedStagesScheduler create( } } - nodeUpdateTask.cancel(true); try { nodeAllocator.close(); } @@ -1933,15 +1930,13 @@ private FaultTolerantDistributedStagesScheduler( QueryStateMachine queryStateMachine, List schedulers, SplitSchedulerStats schedulerStats, - NodeAllocator nodeAllocator, - ScheduledFuture nodeUpdateTask) + NodeAllocator nodeAllocator) { this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.schedulers = requireNonNull(schedulers, "schedulers is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); - this.nodeUpdateTask = requireNonNull(nodeUpdateTask, "nodeUpdateTask is null"); } @Override @@ -2051,7 +2046,6 @@ public void abort() private void closeNodeAllocator() { - nodeUpdateTask.cancel(true); try { nodeAllocator.close(); } diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index ee5c0496f035..b002c719c3ab 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -60,6 +60,8 @@ import io.trino.execution.resourcegroups.InternalResourceGroupManager; import io.trino.execution.resourcegroups.LegacyResourceGroupConfigurationManager; import io.trino.execution.resourcegroups.ResourceGroupManager; +import io.trino.execution.scheduler.FixedCountNodeAllocatorService; +import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.StageTaskSourceFactory; import io.trino.execution.scheduler.TaskDescriptorStorage; @@ -209,6 +211,9 @@ protected void setup(Binder binder) bindLowMemoryKiller(LowMemoryKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, TotalReservationOnBlockedNodesLowMemoryKiller.class); newExporter(binder).export(ClusterMemoryManager.class).withGeneratedName(); + // node allocator + binder.bind(NodeAllocatorService.class).to(FixedCountNodeAllocatorService.class).in(Scopes.SINGLETON); + // node monitor binder.bind(ClusterSizeMonitor.class).in(Scopes.SINGLETON); newExporter(binder).export(ClusterSizeMonitor.class).withGeneratedName(); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index eb5bd3b9475d..775d49bb9e2a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -52,6 +52,7 @@ import io.trino.testing.TestingMetadata.TestingColumnHandle; import io.trino.util.FinalizerService; import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -103,6 +104,7 @@ public class TestFaultTolerantStageScheduler private FinalizerService finalizerService; private NodeTaskMap nodeTaskMap; + private FixedCountNodeAllocatorService nodeAllocatorService; @BeforeClass public void beforeClass() @@ -122,6 +124,21 @@ public void afterClass() } } + private void setupNodeAllocatorService(TestingNodeSupplier nodeSupplier) + { + shutdownNodeAllocatorService(); // just in case + nodeAllocatorService = new FixedCountNodeAllocatorService(new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, nodeSupplier))); + } + + @AfterMethod(alwaysRun = true) + public void shutdownNodeAllocatorService() + { + if (nodeAllocatorService != null) { + nodeAllocatorService.stop(); + } + nodeAllocatorService = null; + } + @Test public void testHappyPath() throws Exception @@ -132,134 +149,137 @@ public void testHappyPath() NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG), NODE_3, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sinkExchange = new TestingExchange(false); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - createNodeAllocator(nodeSupplier), - TaskLifecycleListener.NO_OP, - Optional.of(sinkExchange), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2); - - ListenableFuture blocked = scheduler.isBlocked(); - assertUnblocked(blocked); - - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on first source exchange - assertBlocked(blocked); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - // still blocked on the second source exchange - assertBlocked(blocked); - assertBlocked(scheduler.isBlocked()); - - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - // now unblocked - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on node allocation - assertBlocked(blocked); - - // not all tasks have been enumerated yet - assertFalse(sinkExchange.isNoMoreSinks()); - - Map tasks = remoteTaskFactory.getTasks(); - // one task per node - assertThat(tasks).hasSize(3); - assertThat(tasks).containsKey(getTaskId(0, 0)); - assertThat(tasks).containsKey(getTaskId(1, 0)); - assertThat(tasks).containsKey(getTaskId(2, 0)); - - TestingRemoteTask task = tasks.get(getTaskId(0, 0)); - // fail task for partition 0 - task.fail(new RuntimeException("some failure")); - - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - // schedule more tasks - scheduler.schedule(); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(4); - assertThat(tasks).containsKey(getTaskId(3, 0)); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - // finish some task - assertThat(tasks).containsKey(getTaskId(1, 0)); - tasks.get(getTaskId(1, 0)).finish(); - - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1)); - - // this will schedule failed task - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(5); - assertThat(tasks).containsKey(getTaskId(0, 1)); - - // finish some task - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(3, 0)); - tasks.get(getTaskId(3, 0)).finish(); - assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1), new TestingExchangeSinkHandle(3)); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.of(sinkExchange), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2); - assertUnblocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + assertUnblocked(blocked); - // schedule the last task - scheduler.schedule(); + scheduler.schedule(); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(6); - assertThat(tasks).containsKey(getTaskId(4, 0)); + blocked = scheduler.isBlocked(); + // blocked on first source exchange + assertBlocked(blocked); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + // still blocked on the second source exchange + assertBlocked(blocked); + assertFalse(scheduler.isBlocked().isDone()); + + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + // now unblocked + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); + + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on node allocation + assertBlocked(blocked); + + // not all tasks have been enumerated yet + assertFalse(sinkExchange.isNoMoreSinks()); + + Map tasks = remoteTaskFactory.getTasks(); + // one task per node + assertThat(tasks).hasSize(3); + assertThat(tasks).containsKey(getTaskId(0, 0)); + assertThat(tasks).containsKey(getTaskId(1, 0)); + assertThat(tasks).containsKey(getTaskId(2, 0)); + + TestingRemoteTask task = tasks.get(getTaskId(0, 0)); + // fail task for partition 0 + task.fail(new RuntimeException("some failure")); + + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); + + // schedule more tasks + scheduler.schedule(); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(4); + assertThat(tasks).containsKey(getTaskId(3, 0)); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertBlocked(blocked); + + // finish some task + assertThat(tasks).containsKey(getTaskId(1, 0)); + tasks.get(getTaskId(1, 0)).finish(); + + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); + assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1)); + + // this will schedule failed task + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertBlocked(blocked); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(5); + assertThat(tasks).containsKey(getTaskId(0, 1)); + + // finish some task + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).containsKey(getTaskId(3, 0)); + tasks.get(getTaskId(3, 0)).finish(); + assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1), new TestingExchangeSinkHandle(3)); + + assertUnblocked(blocked); + + // schedule the last task + scheduler.schedule(); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(6); + assertThat(tasks).containsKey(getTaskId(4, 0)); + + // not finished yet, will be finished when all tasks succeed + assertFalse(scheduler.isFinished()); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertBlocked(blocked); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).containsKey(getTaskId(4, 0)); + // finish remaining tasks + tasks.get(getTaskId(0, 1)).finish(); + tasks.get(getTaskId(2, 0)).finish(); + tasks.get(getTaskId(4, 0)).finish(); - // not finished yet, will be finished when all tasks succeed - assertFalse(scheduler.isFinished()); + // now it's not blocked and finished + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); + assertThat(sinkExchange.getFinishedSinkHandles()).contains( + new TestingExchangeSinkHandle(0), + new TestingExchangeSinkHandle(1), + new TestingExchangeSinkHandle(2), + new TestingExchangeSinkHandle(3), + new TestingExchangeSinkHandle(4)); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(4, 0)); - // finish remaining tasks - tasks.get(getTaskId(0, 1)).finish(); - tasks.get(getTaskId(2, 0)).finish(); - tasks.get(getTaskId(4, 0)).finish(); - - // now it's not blocked and finished - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - assertThat(sinkExchange.getFinishedSinkHandles()).contains( - new TestingExchangeSinkHandle(0), - new TestingExchangeSinkHandle(1), - new TestingExchangeSinkHandle(2), - new TestingExchangeSinkHandle(3), - new TestingExchangeSinkHandle(4)); - - assertTrue(scheduler.isFinished()); + assertTrue(scheduler.isFinished()); + } } @Test @@ -271,37 +291,40 @@ public void testTaskLifecycleListener() TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingTaskLifecycleListener taskLifecycleListener = new TestingTaskLifecycleListener(); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - createNodeAllocator(nodeSupplier), - taskLifecycleListener, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + taskLifecycleListener, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); + scheduler.schedule(); + assertBlocked(scheduler.isBlocked()); - assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0)); + assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0)); - remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some exception")); + remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some exception")); - assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); + assertUnblocked(scheduler.isBlocked()); + scheduler.schedule(); + assertBlocked(scheduler.isBlocked()); - assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0), getTaskId(0, 1)); + assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0), getTaskId(0, 1)); + } } @Test @@ -313,46 +336,46 @@ public void testTaskFailure() TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - TaskLifecycleListener.NO_OP, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 0); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 0); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); + scheduler.schedule(); - ListenableFuture blocked = scheduler.isBlocked(); - // waiting on node acquisition - assertBlocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + // waiting on node acquisition + assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); + remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); - assertUnblocked(blocked); - assertUnblocked(acquireNode1); - assertUnblocked(acquireNode2); - assertTrue(acquireNode1.isDone()); - assertTrue(acquireNode2.isDone()); + assertUnblocked(blocked); + assertUnblocked(acquireNode1); + assertUnblocked(acquireNode2); - assertThatThrownBy(scheduler::schedule) - .hasMessageContaining("some failure"); + assertThatThrownBy(scheduler::schedule) + .hasMessageContaining("some failure"); - assertUnblocked(scheduler.isBlocked()); - assertFalse(scheduler.isFinished()); + assertUnblocked(scheduler.isBlocked()); + assertFalse(scheduler.isFinished()); + } } @Test @@ -364,43 +387,45 @@ public void testReportTaskFailure() TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - TaskLifecycleListener.NO_OP, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 1); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 1); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); + scheduler.schedule(); - ListenableFuture blocked = scheduler.isBlocked(); - // waiting for tasks to finish - assertBlocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + // waiting for tasks to finish + assertBlocked(blocked); - scheduler.reportTaskFailure(getTaskId(0, 0), new RuntimeException("some failure")); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); + scheduler.reportTaskFailure(getTaskId(0, 0), new RuntimeException("some failure")); + assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertUnblocked(blocked); - scheduler.schedule(); + assertUnblocked(blocked); + scheduler.schedule(); - assertThat(remoteTaskFactory.getTasks()).containsKey(getTaskId(0, 1)); + assertThat(remoteTaskFactory.getTasks()).containsKey(getTaskId(0, 1)); - remoteTaskFactory.getTasks().get(getTaskId(0, 1)).finish(); - remoteTaskFactory.getTasks().get(getTaskId(1, 0)).finish(); + remoteTaskFactory.getTasks().get(getTaskId(0, 1)).finish(); + remoteTaskFactory.getTasks().get(getTaskId(1, 0)).finish(); - assertUnblocked(scheduler.isBlocked()); - assertTrue(scheduler.isFinished()); + assertUnblocked(scheduler.isBlocked()); + assertTrue(scheduler.isFinished()); + } } @Test @@ -419,48 +444,50 @@ private void testCancellation(boolean abort) TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - TaskLifecycleListener.NO_OP, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 0); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 0); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); + scheduler.schedule(); - ListenableFuture blocked = scheduler.isBlocked(); - // waiting on node acquisition - assertBlocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + // waiting on node acquisition + assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - if (abort) { - scheduler.abort(); - } - else { - scheduler.cancel(); - } + if (abort) { + scheduler.abort(); + } + else { + scheduler.cancel(); + } - assertUnblocked(blocked); - assertUnblocked(acquireNode1); - assertUnblocked(acquireNode2); + assertUnblocked(blocked); + assertUnblocked(acquireNode1); + assertUnblocked(acquireNode2); - scheduler.schedule(); + scheduler.schedule(); - assertUnblocked(scheduler.isBlocked()); - assertFalse(scheduler.isFinished()); + assertUnblocked(scheduler.isBlocked()); + assertFalse(scheduler.isFinished()); + } } private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( @@ -562,12 +589,6 @@ private static List createSplits(int count) return ImmutableList.copyOf(limit(cycle(new Split(CATALOG, createRemoteSplit(), Lifespan.taskWide())), count)); } - private NodeAllocator createNodeAllocator(TestingNodeSupplier nodeSupplier) - { - NodeScheduler nodeScheduler = new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, nodeSupplier)); - return new FixedCountNodeAllocator(nodeScheduler, SESSION, 1); - } - private static TaskId getTaskId(int partitionId, int attemptId) { return new TaskId(STAGE_ID, partitionId, attemptId); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java index 569d1ad420ce..fb1067b8de41 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java @@ -23,6 +23,7 @@ import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; import io.trino.metadata.InternalNode; import io.trino.spi.HostAddress; +import org.testng.annotations.AfterMethod; import org.testng.annotations.Test; import java.net.URI; @@ -35,6 +36,8 @@ import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; +// uses mutable state +@Test(singleThreaded = true) public class TestFixedCountNodeAllocator { private static final Session SESSION = testSessionBuilder().build(); @@ -50,13 +53,31 @@ public class TestFixedCountNodeAllocator private static final CatalogName CATALOG_1 = new CatalogName("catalog1"); private static final CatalogName CATALOG_2 = new CatalogName("catalog2"); + private FixedCountNodeAllocatorService nodeAllocatorService; + + private void setupNodeAllocatorService(TestingNodeSupplier testingNodeSupplier) + { + shutdownNodeAllocatorService(); // just in case + nodeAllocatorService = new FixedCountNodeAllocatorService(new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier))); + } + + @AfterMethod(alwaysRun = true) + public void shutdownNodeAllocatorService() + { + if (nodeAllocatorService != null) { + nodeAllocatorService.stop(); + } + nodeAllocatorService = null; + } + @Test public void testSingleNode() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -70,7 +91,7 @@ public void testSingleNode() assertEquals(acquire2.get(), NODE_1); } - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 2)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -100,8 +121,9 @@ public void testMultipleNodes() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(), NODE_2, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -132,7 +154,7 @@ public void testMultipleNodes() assertEquals(acquire5.get(), NODE_1); } - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 2)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -189,7 +211,9 @@ public void testCatalogRequirement() NODE_2, ImmutableList.of(CATALOG_2), NODE_3, ImmutableList.of(CATALOG_1, CATALOG_2))); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + setupNodeAllocatorService(nodeSupplier); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); assertTrue(catalog1acquire1.isDone()); assertEquals(catalog1acquire1.get(), NODE_1); @@ -239,8 +263,9 @@ public void testCancellation() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -264,8 +289,9 @@ public void testAddNode() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -274,7 +300,7 @@ public void testAddNode() assertFalse(acquire2.isDone()); nodeSupplier.addNode(NODE_2, ImmutableList.of()); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); assertEquals(acquire2.get(10, SECONDS), NODE_2); } @@ -285,8 +311,9 @@ public void testRemoveNode() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -296,7 +323,7 @@ public void testRemoveNode() nodeSupplier.removeNode(NODE_1); nodeSupplier.addNode(NODE_2, ImmutableList.of()); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); assertEquals(acquire2.get(10, SECONDS), NODE_2); @@ -313,7 +340,9 @@ public void testAddressRequirement() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(), NODE_2, ImmutableList.of())); - try (FixedCountNodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + setupNodeAllocatorService(nodeSupplier); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_2); @@ -332,7 +361,7 @@ public void testAddressRequirement() .hasMessageContaining("No nodes available to run query"); nodeSupplier.addNode(NODE_3, ImmutableList.of()); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); assertTrue(acquire4.isDone()); @@ -342,7 +371,7 @@ public void testAddressRequirement() assertFalse(acquire5.isDone()); nodeSupplier.removeNode(NODE_3); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); assertTrue(acquire5.isDone()); assertThatThrownBy(acquire5::get) @@ -350,11 +379,6 @@ public void testAddressRequirement() } } - private FixedCountNodeAllocator createNodeAllocator(TestingNodeSupplier testingNodeSupplier, int maximumAllocationsPerNode) - { - return new FixedCountNodeAllocator(createNodeScheduler(testingNodeSupplier), SESSION, maximumAllocationsPerNode); - } - private NodeScheduler createNodeScheduler(TestingNodeSupplier testingNodeSupplier) { return new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier));