Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix over acquiring of tasks in MultiThreadAgent #496

Merged
merged 4 commits into from
Mar 13, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.BlockingQueue;
import java.time.Duration;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand All @@ -23,8 +26,12 @@ public class MultiThreadAgent
private final TaskServerApi taskServer;
private final OperatorManager runner;
private final ErrorReporter errorReporter;

private final Object taskCountLock = new Object();
private final BlockingQueue<Runnable> executorQueue;
private final ThreadPoolExecutor executor;
private final Object newTaskLock = new Object();
private volatile int activeTaskCount = 0;

private volatile boolean stop = false;

public MultiThreadAgent(
Expand All @@ -37,15 +44,28 @@ public MultiThreadAgent(
this.taskServer = taskServer;
this.runner = runner;
this.errorReporter = errorReporter;

ThreadFactory threadFactory = new ThreadFactoryBuilder()
.setDaemon(false) // make them non-daemon threads so that shutting down agent doesn't kill operator execution
.setNameFormat("task-thread-%d")
.build();

if (config.getMaxThreads() > 0) {
this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(config.getMaxThreads(), threadFactory);
this.executorQueue = new LinkedBlockingQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
config.getMaxThreads(), config.getMaxThreads(),
0L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
else {
this.executor = (ThreadPoolExecutor) Executors.newCachedThreadPool(threadFactory);
// If there is no upper limit on the number of threads, the queue doesn't actually need to store entries.
// Instead, executor.submit() hands each task directly to an idle thread, or starts a new thread if none is idle.
// SynchronousQueue.size() always returns 0.
this.executorQueue = new SynchronousQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
}

Expand All @@ -54,15 +74,15 @@ public void shutdown(Optional<Duration> maximumCompletionWait)
{
stop = true;
taskServer.interruptLocalWait();
int activeCount;
synchronized (newTaskLock) {
// synchronize newTaskLock not to reject task execution after acquiring them from taskServer
int activeTaskCountSnapshot;
synchronized (taskCountLock) {
// hold taskCountLock while shutting down so that tasks already acquired from taskServer are not rejected by the executor
executor.shutdown();
activeCount = executor.getActiveCount();
newTaskLock.notifyAll();
activeTaskCountSnapshot = activeTaskCount;
taskCountLock.notifyAll();
}
if (activeCount > 0) {
logger.info("Waiting for completion of {} running tasks...", activeCount);
if (activeTaskCountSnapshot > 0) {
logger.info("Waiting for completion of {} running tasks...", activeTaskCountSnapshot);
}
if (maximumCompletionWait.isPresent()) {
long seconds = maximumCompletionWait.get().getSeconds();
Expand All @@ -82,13 +102,13 @@ public void run()
{
while (!stop) {
try {
synchronized (newTaskLock) {
synchronized (taskCountLock) {
if (executor.isShutdown()) {
break;
}
int max = Math.min(executor.getMaximumPoolSize() - executor.getActiveCount(), 10);
if (max > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(max, agentId, config.getLockRetentionTime(), 1000);
int maxAcquire = Math.min(executor.getMaximumPoolSize() - activeTaskCount, 10);
if (maxAcquire > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(maxAcquire, agentId, config.getLockRetentionTime(), 1000);
for (TaskRequest req : reqs) {
executor.submit(() -> {
try {
Expand All @@ -98,12 +118,18 @@ public void run()
logger.error("Uncaught exception. Task queue will detect this failure and this task will be retried later.", t);
errorReporter.reportUncaughtError(t);
}
finally {
synchronized (taskCountLock) {
activeTaskCount--;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using AtomicInteger makes it a bit simpler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With SynchronousQueue, calling executor.submit() can be blocked holding taskCountLock. As a result, a thread that tries to decrement the counter here can be potentially blocked forever?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SynchronousQueue is used with maximumPoolSize: Integer.MAX_VALUE. So it's usually okay. But it seems a bit naive to me...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point 👍 I changed the code to use AtomicInteger. AgentOverAcquireIT still passes.

}
}
});
activeTaskCount++;
}
}
else {
// no executor thread is available. sleep for a while until a task execution finishes
newTaskLock.wait(500);
taskCountLock.wait(500);
}
}
}
Expand Down
62 changes: 62 additions & 0 deletions digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package acceptance;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assume.assumeThat;
import static utils.TestUtils.copyResource;
import static utils.TestUtils.main;
import utils.CommandStatus;

/**
 * Acceptance test for the agent's task-acquisition limit.
 *
 * <p>Runs the {@code over_acquire.dig} workflow with 5 task threads and a
 * 20-second lock retention time, then checks that each of the 20 output
 * files contains exactly {@code 1}.
 */
public class AgentOverAcquireIT
{
    @Rule
    public TemporaryFolder folder = new TemporaryFolder();

    private Path projectDir;
    private Path config;
    private Path outdir;

    /**
     * Prepares a project directory, an empty config file and an output
     * directory inside the temporary folder.
     */
    @Before
    public void setUp()
            throws Exception
    {
        projectDir = folder.getRoot().toPath();
        config = folder.newFile().toPath();

        Path outputDirectory = projectDir.resolve("outdir");
        Files.createDirectories(outputDirectory);
        outdir = outputDirectory;
    }

    /**
     * Runs the workflow and verifies the exit code and every output file.
     */
    @Test
    public void testOverAcquire()
            throws Exception
    {
        // Always skipped unless this assumption is edited: the run takes too long for the default suite.
        assumeThat(true, is(false));

        Path workflowFile = projectDir.resolve("over_acquire.dig");
        copyResource("acceptance/over_acquire/over_acquire.dig", workflowFile);

        String projectPath = projectDir.toString();
        CommandStatus status = main("run",
                "-o", projectPath,
                "--config", config.toString(),
                "--project", projectPath,
                "-X", "agent.heartbeat-interval=5",
                "-X", "agent.lock-retention-time=20",
                "-X", "agent.max-task-threads=5",
                "-p", "outdir=" + outdir,
                "over_acquire.dig");
        String stderr = status.errUtf8();
        assertThat(stderr, status.code(), is(0));

        int index = 0;
        while (index < 20) {
            assertThat(readTrimmed(outdir.resolve(Integer.toString(index))), is("1"));
            index++;
        }
    }

    /**
     * Reads the whole file as UTF-8 text and trims surrounding whitespace.
     */
    private static String readTrimmed(Path file)
            throws Exception
    {
        byte[] raw = Files.readAllBytes(file);
        return new String(raw, UTF_8).trim();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

# Workflow fixture used by AgentOverAcquireIT.
# NOTE(review): nesting indentation appears to have been lost in this view; confirm against the real file.
+loop:
# Repeat the body 20 times (the test reads output files named 0..19).
loop>: 20
# Run the iterations in parallel rather than one after another.
_parallel: true
_do:
# Each iteration sleeps 30 seconds, which is longer than the 20-second lock retention
# time the test configures, then appends a line containing "1" to ${outdir}/${i}.
# The test expects each file to contain exactly "1".
sh>: sleep 30 && echo 1 >> ${outdir}/${i}