diff --git a/digdag-core/src/main/java/io/digdag/core/agent/MultiThreadAgent.java b/digdag-core/src/main/java/io/digdag/core/agent/MultiThreadAgent.java index fbad9025c1..3805b31e90 100644 --- a/digdag-core/src/main/java/io/digdag/core/agent/MultiThreadAgent.java +++ b/digdag-core/src/main/java/io/digdag/core/agent/MultiThreadAgent.java @@ -1,15 +1,19 @@ package io.digdag.core.agent; +import com.google.common.base.Optional; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.digdag.core.ErrorReporter; +import io.digdag.spi.TaskRequest; +import java.time.Duration; import java.util.List; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executors; -import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.time.Duration; -import com.google.common.base.Optional; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.digdag.spi.TaskRequest; -import io.digdag.core.ErrorReporter; +import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,8 +27,12 @@ public class MultiThreadAgent private final TaskServerApi taskServer; private final OperatorManager runner; private final ErrorReporter errorReporter; + + private final Object addActiveTaskLock = new Object(); + private final BlockingQueue executorQueue; private final ThreadPoolExecutor executor; - private final Object newTaskLock = new Object(); + private final AtomicInteger activeTaskCount = new AtomicInteger(0); + private volatile boolean stop = false; public MultiThreadAgent( @@ -37,15 +45,28 @@ public MultiThreadAgent( this.taskServer = taskServer; this.runner = runner; this.errorReporter = errorReporter; + ThreadFactory threadFactory = new ThreadFactoryBuilder() 
.setDaemon(false) // make them non-daemon threads so that shutting down agent doesn't kill operator execution .setNameFormat("task-thread-%d") .build(); + if (config.getMaxThreads() > 0) { - this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(config.getMaxThreads(), threadFactory); + this.executorQueue = new LinkedBlockingQueue(); + this.executor = new ThreadPoolExecutor( + config.getMaxThreads(), config.getMaxThreads(), + 0L, TimeUnit.SECONDS, + executorQueue, threadFactory); } else { - this.executor = (ThreadPoolExecutor) Executors.newCachedThreadPool(threadFactory); + // If there is no upper limit on the number of threads, the queue doesn't actually need to store entries. + // Instead, executor.submit() hands each task directly to an idle thread or starts a new thread for it. + // SynchronousQueue.size() always returns 0. + this.executorQueue = new SynchronousQueue(); + this.executor = new ThreadPoolExecutor( + 0, Integer.MAX_VALUE, + 60L, TimeUnit.SECONDS, + executorQueue, threadFactory); } } @@ -54,15 +75,15 @@ public void shutdown(Optional maximumCompletionWait) { stop = true; taskServer.interruptLocalWait(); - int activeCount; - synchronized (newTaskLock) { - // synchronize newTaskLock not to reject task execution after acquiring them from taskServer - executor.shutdown(); - activeCount = executor.getActiveCount(); - newTaskLock.notifyAll(); + int maximumActiveTasks; + synchronized (addActiveTaskLock) { + // synchronize addActiveTaskLock not to reject task execution after acquiring them from taskServer + executor.shutdown(); // From here on, no one can increase activeTaskCount. + maximumActiveTasks = activeTaskCount.get(); // Now get the maximum count. 
+ addActiveTaskLock.notifyAll(); } - if (activeCount > 0) { - logger.info("Waiting for completion of {} running tasks...", activeCount); + if (maximumActiveTasks > 0) { + logger.info("Waiting for completion of {} running tasks...", maximumActiveTasks); } if (maximumCompletionWait.isPresent()) { long seconds = maximumCompletionWait.get().getSeconds(); @@ -82,13 +103,18 @@ public void run() { while (!stop) { try { - synchronized (newTaskLock) { + synchronized (addActiveTaskLock) { if (executor.isShutdown()) { break; } - int max = Math.min(executor.getMaximumPoolSize() - executor.getActiveCount(), 10); - if (max > 0) { - List reqs = taskServer.lockSharedAgentTasks(max, agentId, config.getLockRetentionTime(), 1000); + // Because addActiveTaskLock is locked, no one increases activeTaskCount in this synchronized block. Now get the maximum count. + int maximumActiveTasks = activeTaskCount.get(); + // Because the maximum count doesn't increase, we know here that at least N threads are idle. + int guaranteedAvaialbleThreads = executor.getMaximumPoolSize() - maximumActiveTasks; + // Acquire at most the smaller of guaranteedAvaialbleThreads and 10. This guarantees that all tasks start immediately. + int maxAcquire = Math.min(guaranteedAvaialbleThreads, 10); + if (maxAcquire > 0) { + List reqs = taskServer.lockSharedAgentTasks(maxAcquire, agentId, config.getLockRetentionTime(), 1000); for (TaskRequest req : reqs) { executor.submit(() -> { try { @@ -98,12 +124,16 @@ public void run() logger.error("Uncaught exception. Task queue will detect this failure and this task will be retried later.", t); errorReporter.reportUncaughtError(t); } + finally { + activeTaskCount.decrementAndGet(); + } }); + activeTaskCount.incrementAndGet(); } } else { // no executor thread is available. 
sleep for a while until a task execution finishes - newTaskLock.wait(500); + addActiveTaskLock.wait(500); } } } diff --git a/digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java b/digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java new file mode 100644 index 0000000000..9c194670eb --- /dev/null +++ b/digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java @@ -0,0 +1,62 @@ +package acceptance; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.junit.Assume.assumeThat; +import static utils.TestUtils.copyResource; +import static utils.TestUtils.main; +import utils.CommandStatus; + +public class AgentOverAcquireIT +{ + @Rule + public TemporaryFolder folder = new TemporaryFolder(); + + private Path projectDir; + private Path config; + private Path outdir; + + @Before + public void setUp() + throws Exception + { + projectDir = folder.getRoot().toPath(); + config = folder.newFile().toPath(); + + outdir = projectDir.resolve("outdir"); + Files.createDirectories(outdir); + } + + @Test + public void testOverAcquire() + throws Exception + { + assumeThat(true, is(false)); // disabled by default to avoid too long execution time. 
+ + copyResource("acceptance/over_acquire/over_acquire.dig", projectDir.resolve("over_acquire.dig")); + + CommandStatus runStatus = main("run", + "-o", projectDir.toString(), + "--config", config.toString(), + "--project", projectDir.toString(), + "-X", "agent.heartbeat-interval=5", + "-X", "agent.lock-retention-time=20", + "-X", "agent.max-task-threads=5", + "-p", "outdir=" + outdir, + "over_acquire.dig"); + assertThat(runStatus.errUtf8(), runStatus.code(), is(0)); + + for (int i = 0; i < 20; i++) { + String one = new String(Files.readAllBytes(outdir.resolve(Integer.toString(i))), UTF_8).trim(); + assertThat(one, is("1")); + } + } +} diff --git a/digdag-tests/src/test/resources/acceptance/over_acquire/over_acquire.dig b/digdag-tests/src/test/resources/acceptance/over_acquire/over_acquire.dig new file mode 100644 index 0000000000..98a2f49a77 --- /dev/null +++ b/digdag-tests/src/test/resources/acceptance/over_acquire/over_acquire.dig @@ -0,0 +1,7 @@ + ++loop: + loop>: 20 + _parallel: true + _do: + sh>: sleep 30 && echo 1 >> ${outdir}/${i} +