Skip to content

Commit

Permalink
Fix over acquiring of tasks in MultiThreadAgent
Browse files Browse the repository at this point in the history
Fixes #487.
  • Loading branch information
frsyuki committed Mar 3, 2017
1 parent d333526 commit 14c21f5
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.BlockingQueue;
import java.time.Duration;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand All @@ -23,8 +26,11 @@ public class MultiThreadAgent
private final TaskServerApi taskServer;
private final OperatorManager runner;
private final ErrorReporter errorReporter;
private final ThreadPoolExecutor executor;

private final Object newTaskLock = new Object();
private final BlockingQueue<Runnable> executorQueue;
private final ThreadPoolExecutor executor;

private volatile boolean stop = false;

public MultiThreadAgent(
Expand All @@ -37,15 +43,28 @@ public MultiThreadAgent(
this.taskServer = taskServer;
this.runner = runner;
this.errorReporter = errorReporter;

ThreadFactory threadFactory = new ThreadFactoryBuilder()
.setDaemon(false) // make them non-daemon threads so that shutting down agent doesn't kill operator execution
.setNameFormat("task-thread-%d")
.build();

if (config.getMaxThreads() > 0) {
this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(config.getMaxThreads(), threadFactory);
this.executorQueue = new LinkedBlockingQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
0, config.getMaxThreads(),
60L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
else {
this.executor = (ThreadPoolExecutor) Executors.newCachedThreadPool(threadFactory);
// If there're no upper limit on number of threads, queue actually doesn't need to store entries.
// Instead, executor.submit() blocks until a thread starts and takes it.
// SynchronousQueue.size() always returns 0.
this.executorQueue = new SynchronousQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
}

Expand All @@ -54,15 +73,15 @@ public void shutdown(Optional<Duration> maximumCompletionWait)
{
stop = true;
taskServer.interruptLocalWait();
int activeCount;
int maximumPossibleActiveTaskCount;
synchronized (newTaskLock) {
// synchronize newTaskLock not to reject task execution after acquiring them from taskServer
executor.shutdown();
activeCount = executor.getActiveCount();
maximumPossibleActiveTaskCount = executorQueue.size() + executor.getActiveCount();
newTaskLock.notifyAll();
}
if (activeCount > 0) {
logger.info("Waiting for completion of {} running tasks...", activeCount);
if (maximumPossibleActiveTaskCount > 0) {
logger.info("Waiting for completion of {} running tasks...", maximumPossibleActiveTaskCount);
}
if (maximumCompletionWait.isPresent()) {
long seconds = maximumCompletionWait.get().getSeconds();
Expand All @@ -86,9 +105,11 @@ public void run()
if (executor.isShutdown()) {
break;
}
int max = Math.min(executor.getMaximumPoolSize() - executor.getActiveCount(), 10);
if (max > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(max, agentId, config.getLockRetentionTime(), 1000);
int maximumPossibleActiveTaskCount = executorQueue.size() + executor.getActiveCount();
int guaranteedAvaialbleThreads = executor.getMaximumPoolSize() - maximumPossibleActiveTaskCount;
int maxAcquire = Math.min(guaranteedAvaialbleThreads, 10);
if (maxAcquire > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(maxAcquire, agentId, config.getLockRetentionTime(), 1000);
for (TaskRequest req : reqs) {
executor.submit(() -> {
try {
Expand Down
62 changes: 62 additions & 0 deletions digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package acceptance;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assume.assumeThat;
import static utils.TestUtils.copyResource;
import static utils.TestUtils.main;
import utils.CommandStatus;

public class AgentOverAcquireIT
{
@Rule
public TemporaryFolder folder = new TemporaryFolder();

private Path projectDir;
private Path config;
private Path outdir;

@Before
public void setUp()
throws Exception
{
projectDir = folder.getRoot().toPath();
config = folder.newFile().toPath();

outdir = projectDir.resolve("outdir");
Files.createDirectories(outdir);
}

@Test
public void testOverAcquire()
throws Exception
{
assumeThat(true, is(false)); // disabled by default to avoid too long execution time.

copyResource("acceptance/over_acquire/over_acquire.dig", projectDir.resolve("over_acquire.dig"));

CommandStatus runStatus = main("run",
"-o", projectDir.toString(),
"--config", config.toString(),
"--project", projectDir.toString(),
"-X", "agent.heartbeat-interval=10",
"-X", "agent.lock-retention-time=20",
"-X", "agent.max-task-threads=5",
"-p", "outdir=" + outdir,
"over_acquire.dig");
assertThat(runStatus.errUtf8(), runStatus.code(), is(0));

for (int i = 0; i < 20; i++) {
String one = new String(Files.readAllBytes(outdir.resolve(Integer.toString(i))), UTF_8);
assertThat(one, is("1"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

+loop:
loop>: 20
_parallel: true
_do:
sh>: sleep 30 && /bin/echo -n 1 >> ${outdir}/${i}

0 comments on commit 14c21f5

Please sign in to comment.