Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix over acquiring of tasks in MultiThreadAgent #496

Merged
merged 4 commits into from
Mar 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package io.digdag.core.agent;

import com.google.common.base.Optional;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.digdag.core.ErrorReporter;
import io.digdag.spi.TaskRequest;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.time.Duration;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.digdag.spi.TaskRequest;
import io.digdag.core.ErrorReporter;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -23,8 +27,12 @@ public class MultiThreadAgent
private final TaskServerApi taskServer;
private final OperatorManager runner;
private final ErrorReporter errorReporter;

private final Object addActiveTaskLock = new Object();
private final BlockingQueue<Runnable> executorQueue;
private final ThreadPoolExecutor executor;
private final Object newTaskLock = new Object();
private final AtomicInteger activeTaskCount = new AtomicInteger(0);

private volatile boolean stop = false;

public MultiThreadAgent(
Expand All @@ -37,15 +45,28 @@ public MultiThreadAgent(
this.taskServer = taskServer;
this.runner = runner;
this.errorReporter = errorReporter;

ThreadFactory threadFactory = new ThreadFactoryBuilder()
.setDaemon(false) // make them non-daemon threads so that shutting down agent doesn't kill operator execution
.setNameFormat("task-thread-%d")
.build();

if (config.getMaxThreads() > 0) {
this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(config.getMaxThreads(), threadFactory);
this.executorQueue = new LinkedBlockingQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
config.getMaxThreads(), config.getMaxThreads(),
0L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
else {
this.executor = (ThreadPoolExecutor) Executors.newCachedThreadPool(threadFactory);
// If there're no upper limit on number of threads, queue actually doesn't need to store entries.
// Instead, executor.submit() blocks until a thread starts and takes it.
// SynchronousQueue.size() always returns 0.
this.executorQueue = new SynchronousQueue<Runnable>();
this.executor = new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
executorQueue, threadFactory);
}
}

Expand All @@ -54,15 +75,15 @@ public void shutdown(Optional<Duration> maximumCompletionWait)
{
stop = true;
taskServer.interruptLocalWait();
int activeCount;
synchronized (newTaskLock) {
// synchronize newTaskLock not to reject task execution after acquiring them from taskServer
executor.shutdown();
activeCount = executor.getActiveCount();
newTaskLock.notifyAll();
int maximumActiveTasks;
synchronized (addActiveTaskLock) {
// synchronize addActiveTaskLock not to reject task execution after acquiring them from taskServer
executor.shutdown(); // Since here, no one can increase activeTaskCount.
maximumActiveTasks = activeTaskCount.get(); /// Now get the maximum count.
addActiveTaskLock.notifyAll();
}
if (activeCount > 0) {
logger.info("Waiting for completion of {} running tasks...", activeCount);
if (maximumActiveTasks > 0) {
logger.info("Waiting for completion of {} running tasks...", maximumActiveTasks);
}
if (maximumCompletionWait.isPresent()) {
long seconds = maximumCompletionWait.get().getSeconds();
Expand All @@ -82,13 +103,18 @@ public void run()
{
while (!stop) {
try {
synchronized (newTaskLock) {
synchronized (addActiveTaskLock) {
if (executor.isShutdown()) {
break;
}
int max = Math.min(executor.getMaximumPoolSize() - executor.getActiveCount(), 10);
if (max > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(max, agentId, config.getLockRetentionTime(), 1000);
// Because addActiveTaskLock is locked, no one increases activeTaskCount in this synchronized block. Now get the maximum count.
int maximumActiveTasks = activeTaskCount.get();
// Because the maximum count doesn't increase, here can know that at least N number of threads are idling.
int guaranteedAvaialbleThreads = executor.getMaximumPoolSize() - maximumActiveTasks;
// Acquire at most guaranteedAvaialbleThreads or 10. This guarantees that all tasks start immediately.
int maxAcquire = Math.min(guaranteedAvaialbleThreads, 10);
if (maxAcquire > 0) {
List<TaskRequest> reqs = taskServer.lockSharedAgentTasks(maxAcquire, agentId, config.getLockRetentionTime(), 1000);
for (TaskRequest req : reqs) {
executor.submit(() -> {
try {
Expand All @@ -98,12 +124,16 @@ public void run()
logger.error("Uncaught exception. Task queue will detect this failure and this task will be retried later.", t);
errorReporter.reportUncaughtError(t);
}
finally {
activeTaskCount.decrementAndGet();
}
});
activeTaskCount.incrementAndGet();
}
}
else {
// no executor thread is available. sleep for a while until a task execution finishes
newTaskLock.wait(500);
addActiveTaskLock.wait(500);
}
}
}
Expand Down
62 changes: 62 additions & 0 deletions digdag-tests/src/test/java/acceptance/AgentOverAcquireIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package acceptance;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assume.assumeThat;
import static utils.TestUtils.copyResource;
import static utils.TestUtils.main;
import utils.CommandStatus;

public class AgentOverAcquireIT
{
@Rule
public TemporaryFolder folder = new TemporaryFolder();

private Path projectDir;
private Path config;
private Path outdir;

@Before
public void setUp()
throws Exception
{
projectDir = folder.getRoot().toPath();
config = folder.newFile().toPath();

outdir = projectDir.resolve("outdir");
Files.createDirectories(outdir);
}

@Test
public void testOverAcquire()
throws Exception
{
assumeThat(true, is(false)); // disabled by default to avoid too long execution time.

copyResource("acceptance/over_acquire/over_acquire.dig", projectDir.resolve("over_acquire.dig"));

CommandStatus runStatus = main("run",
"-o", projectDir.toString(),
"--config", config.toString(),
"--project", projectDir.toString(),
"-X", "agent.heartbeat-interval=5",
"-X", "agent.lock-retention-time=20",
"-X", "agent.max-task-threads=5",
"-p", "outdir=" + outdir,
"over_acquire.dig");
assertThat(runStatus.errUtf8(), runStatus.code(), is(0));

for (int i = 0; i < 20; i++) {
String one = new String(Files.readAllBytes(outdir.resolve(Integer.toString(i))), UTF_8).trim();
assertThat(one, is("1"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

+loop:
loop>: 20
_parallel: true
_do:
sh>: sleep 30 && echo 1 >> ${outdir}/${i}