Skip to content

Commit

Permalink
Merge pull request #496 from treasure-data/max-threads-overacquire-fix
Browse files Browse the repository at this point in the history
Fix over acquiring of tasks in MultiThreadAgent
  • Loading branch information
frsyuki authored Mar 13, 2017
2 parents 47d74e0 + 6926ec0 commit 248f7a6
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package io.digdag.core.agent;

import com.google.common.base.Optional;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.digdag.core.ErrorReporter;
import io.digdag.spi.TaskRequest;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.time.Duration;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.digdag.spi.TaskRequest;
import io.digdag.core.ErrorReporter;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -23,8 +27,12 @@ public class MultiThreadAgent
private final TaskServerApi taskServer;
private final OperatorManager runner;
private final ErrorReporter errorReporter;

private final Object addActiveTaskLock = new Object();
private final BlockingQueue<Runnable> executorQueue;
private final ThreadPoolExecutor executor;
private final Object newTaskLock = new Object();
private final AtomicInteger activeTaskCount = new AtomicInteger(0);

private volatile boolean stop = false;

public MultiThreadAgent(
Expand All @@ -37,15 +45,28 @@ public MultiThreadAgent(
this.taskServer = taskServer;
this.runner = runner;
this.errorReporter = errorReporter;

ThreadFactory threadFactory = new ThreadFactoryBuilder()
.setDaemon(false) // make them non-daemon threads so that shutting down agent doesn't kill operator execution
.setNameFormat("task-thread-%d")
.build();

if (config.getMaxThreads() > 0) {
this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(config.getMaxThreads(), threadFactory);
this.executorQueue = new LinkedBlockingQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
config.getMaxThreads(), config.getMaxThreads(),
0L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
else {
this.executor = (ThreadPoolExecutor) Executors.newCachedThreadPool(threadFactory);
// If there is no upper limit on the number of threads, the queue doesn't need to store entries.
// Instead, executor.submit() hands each task directly to an idle thread, or to a newly created one if none is idle.
// SynchronousQueue.size() always returns 0.
this.executorQueue = new SynchronousQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
}

Expand All @@ -54,15 +75,15 @@ public void shutdown(Optional<Duration> maximumCompletionWait)
{
stop = true;
taskServer.interruptLocalWait();
int activeCount;
synchronized (newTaskLock) {
// synchronize newTaskLock not to reject task execution after acquiring them from taskServer
executor.shutdown();
activeCount = executor.getActiveCount();
newTaskLock.notifyAll();
int maximumActiveTasks;
synchronized (addActiveTaskLock) {
// synchronize addActiveTaskLock not to reject task execution after acquiring them from taskServer
executor.shutdown(); // Since here, no one can increase activeTaskCount.
maximumActiveTasks = activeTaskCount.get(); /// Now get the maximum count.
addActiveTaskLock.notifyAll();
}
if (activeCount > 0) {
logger.info("Waiting for completion of {} running tasks...", activeCount);
if (maximumActiveTasks > 0) {
logger.info("Waiting for completion of {} running tasks...", maximumActiveTasks);
}
if (maximumCompletionWait.isPresent()) {
long seconds = maximumCompletionWait.get().getSeconds();
Expand All @@ -82,13 +103,18 @@ public void run()
{
while (!stop) {
try {
synchronized (newTaskLock) {
synchronized (addActiveTaskLock) {
if (executor.isShutdown()) {
break;
}
int max = Math.min(executor.getMaximumPoolSize() - executor.getActiveCount(), 10);
if (max > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(max, agentId, config.getLockRetentionTime(), 1000);
// Because addActiveTaskLock is locked, no one increases activeTaskCount in this synchronized block. Now get the maximum count.
int maximumActiveTasks = activeTaskCount.get();
// Because the count cannot increase while the lock is held, at least (maximum pool size - maximumActiveTasks) threads are known to be idle.
int guaranteedAvaialbleThreads = executor.getMaximumPoolSize() - maximumActiveTasks;
// Acquire at most guaranteedAvaialbleThreads or 10. This guarantees that all tasks start immediately.
int maxAcquire = Math.min(guaranteedAvaialbleThreads, 10);
if (maxAcquire > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(maxAcquire, agentId, config.getLockRetentionTime(), 1000);
for (TaskRequest req : reqs) {
executor.submit(() -> {
try {
Expand All @@ -98,12 +124,16 @@ public void run()
logger.error("Uncaught exception. Task queue will detect this failure and this task will be retried later.", t);
errorReporter.reportUncaughtError(t);
}
finally {
activeTaskCount.decrementAndGet();
}
});
activeTaskCount.incrementAndGet();
}
}
else {
// no executor thread is available. sleep for a while until a task execution finishes
newTaskLock.wait(500);
addActiveTaskLock.wait(500);
}
}
}
Expand Down
62 changes: 62 additions & 0 deletions digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package acceptance;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assume.assumeThat;
import static utils.TestUtils.copyResource;
import static utils.TestUtils.main;
import utils.CommandStatus;

/**
 * Acceptance test that runs the {@code over_acquire.dig} workflow through the {@code run}
 * command with {@code agent.max-task-threads=5} and then checks the output files named
 * {@code 0} through {@code 19}.
 *
 * <p>The test body is skipped unconditionally by an always-false assumption at the top of
 * {@link #testOverAcquire()}.
 */
public class AgentOverAcquireIT
{
    @Rule
    public TemporaryFolder folder = new TemporaryFolder();

    private Path projectDir;
    private Path config;
    private Path outdir;

    /**
     * Uses the temporary folder root as the project directory, creates a new file in it to
     * serve as the config file, and creates the {@code outdir} subdirectory.
     */
    @Before
    public void setUp()
            throws Exception
    {
        Path root = folder.getRoot().toPath();
        projectDir = root;
        config = folder.newFile().toPath();

        Path outputDirectory = root.resolve("outdir");
        outdir = outputDirectory;
        Files.createDirectories(outputDirectory);
    }

    /**
     * Runs the workflow and asserts that the command exits with code 0 and that each of the
     * 20 output files contains exactly {@code 1} after trimming.
     */
    @Test
    public void testOverAcquire()
            throws Exception
    {
        // disabled by default to avoid too long execution time.
        assumeThat(true, is(false));

        Path workflowFile = projectDir.resolve("over_acquire.dig");
        copyResource("acceptance/over_acquire/over_acquire.dig", workflowFile);

        String projectPath = projectDir.toString();
        String outdirParam = "outdir=" + outdir;
        CommandStatus runStatus = main(
                "run",
                "-o", projectPath,
                "--config", config.toString(),
                "--project", projectPath,
                "-X", "agent.heartbeat-interval=5",
                "-X", "agent.lock-retention-time=20",
                "-X", "agent.max-task-threads=5",
                "-p", outdirParam,
                "over_acquire.dig");
        // The captured stderr is used as the assertion message when the exit code is not 0.
        assertThat(runStatus.errUtf8(), runStatus.code(), is(0));

        // One output file is expected per loop index.
        for (int taskIndex = 0; taskIndex < 20; taskIndex++) {
            Path resultFile = outdir.resolve(String.valueOf(taskIndex));
            byte[] rawContent = Files.readAllBytes(resultFile);
            String content = new String(rawContent, UTF_8).trim();
            assertThat(content, is("1"));
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

+loop:
loop>: 20
_parallel: true
_do:
sh>: sleep 30 && echo 1 >> ${outdir}/${i}

0 comments on commit 248f7a6

Please sign in to comment.