From e9e6978809b0214e336fee05047d5befe4f4e0c3 Mon Sep 17 00:00:00 2001 From: larsrc Date: Thu, 15 Apr 2021 02:26:20 -0700 Subject: [PATCH] Server-side implementation of worker cancellation. RELNOTES: None. PiperOrigin-RevId: 368598866 --- .../devtools/build/lib/worker/WorkerKey.java | 9 +- .../build/lib/worker/WorkerSpawnRunner.java | 29 +++++- .../build/lib/worker/ExampleWorker.java | 37 +++---- .../lib/worker/ExampleWorkerOptions.java | 6 +- .../devtools/build/lib/worker/TestUtils.java | 2 + .../build/lib/worker/WorkerFactoryTest.java | 1 + .../build/lib/worker/WorkerKeyTest.java | 3 + .../worker/WorkerMultiplexerManagerTest.java | 2 + .../lib/worker/WorkerSpawnRunnerTest.java | 99 ++++++++++++++++++- 9 files changed, 158 insertions(+), 30 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerKey.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerKey.java index d04ecaba4f2f0a..491cb3fa1b8c28 100644 --- a/src/main/java/com/google/devtools/build/lib/worker/WorkerKey.java +++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerKey.java @@ -53,6 +53,8 @@ final class WorkerKey { private final boolean isSpeculative; /** A WorkerProxy will be instantiated if true, instantiate a regular Worker if false. */ private final boolean proxied; + /** If true, the workers for this key are able to cancel work requests. */ + private final boolean cancellable; /** * Cached value for the hash of this key, because the value is expensive to calculate * (ImmutableMap and ImmutableList do not cache their hashcodes. @@ -70,6 +72,7 @@ final class WorkerKey { SortedMap workerFilesWithHashes, boolean isSpeculative, boolean proxied, + boolean cancellable, WorkerProtocolFormat protocolFormat) { this.args = Preconditions.checkNotNull(args); this.env = Preconditions.checkNotNull(env); @@ -79,8 +82,8 @@ final class WorkerKey { this.workerFilesWithHashes = Preconditions.checkNotNull(workerFilesWithHashes); this.isSpeculative = isSpeculative; this.proxied = proxied; + this.cancellable = cancellable; this.protocolFormat = protocolFormat; - hash = calculateHashCode(); } @@ -128,6 +131,10 @@ public boolean isMultiplex() { return getProxied() && !isSpeculative; } + public boolean isCancellable() { + return cancellable; + } + /** Returns the format of the worker protocol. */ public WorkerProtocolFormat getProtocolFormat() { return protocolFormat; diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java index d144e476f362dd..2682aab81c563c 100644 --- a/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java +++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java @@ -77,8 +77,6 @@ final class WorkerSpawnRunner implements SpawnRunner { public static final String REASON_NO_FLAGFILE = "because the command-line arguments do not contain at least one @flagfile or --flagfile="; public static final String REASON_NO_TOOLS = "because the action has no tools"; - public static final String REASON_NO_EXECUTION_INFO = - "because the action's execution info does not contain 'supports-workers=1'"; /** Pattern for @flagfile.txt and --flagfile=flagfile.txt */ private static final Pattern FLAG_FILE_PATTERN = Pattern.compile("(?:@|--?flagfile=)(.+)"); @@ -205,6 +203,7 @@ public SpawnResult exec(Spawn spawn, SpawnExecutionContext context) workerFiles, context.speculating(), multiplex && Spawns.supportsMultiplexWorkers(spawn), + Spawns.supportsWorkerCancellation(spawn), protocolFormat); SpawnMetrics.Builder spawnMetrics = @@ -458,7 +457,11 @@ WorkResponse execInWorker( try { response = worker.getResponse(request.getRequestId()); } catch (InterruptedException e) { - finishWorkAsync(key, worker, request); + finishWorkAsync( + key, + worker, + request, + workerOptions.workerCancellation && Spawns.supportsWorkerCancellation(spawn)); worker = null; throw e; } catch (IOException e) { @@ -480,6 +483,12 @@ WorkResponse execInWorker( throw createEmptyResponseException(worker.getLogFile()); } + if (response.getWasCancelled()) { + throw createUserExecException( + "Received cancel response for " + response.getRequestId() + " without having cancelled", + Code.FINISH_FAILURE); + } + try { Stopwatch processOutputsStopwatch = Stopwatch.createStarted(); context.lockOutputFiles(); @@ -525,12 +534,21 @@ WorkResponse execInWorker( * interrupted. This takes ownership of the worker for purposes of returning it to the worker * pool. */ - private void finishWorkAsync(WorkerKey key, Worker worker, WorkRequest request) { + private void finishWorkAsync( + WorkerKey key, Worker worker, WorkRequest request, boolean canCancel) { Thread reaper = new Thread( () -> { Worker w = worker; try { + if (canCancel) { + WorkRequest cancelRequest = + WorkRequest.newBuilder() + .setRequestId(request.getRequestId()) + .setCancel(true) + .build(); + w.putRequest(cancelRequest); + } w.getResponse(request.getRequestId()); } catch (IOException | InterruptedException e1) { // If this happens, we either can't trust the output of the worker, or we got @@ -549,7 +567,8 @@ private void finishWorkAsync(WorkerKey key, Worker worker, WorkRequest request) workers.returnObject(key, w); } } - }); + }, + "AsyncFinish-Worker-" + worker.workerId); reaper.start(); } diff --git a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java index c79237bab82fd6..ad70a18d262d89 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java +++ b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableSet; import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat; import com.google.devtools.build.lib.worker.ExampleWorkerOptions.ExampleWorkOptions; +import com.google.devtools.build.lib.worker.WorkRequestHandler.WorkerMessageProcessor; import com.google.devtools.build.lib.worker.WorkerProtocol.Input; import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest; import com.google.devtools.common.options.OptionsParser; @@ -42,12 +43,9 @@ import java.util.Map; import java.util.Random; import java.util.UUID; -import java.util.concurrent.Semaphore; import java.util.function.BiFunction; import java.util.regex.Matcher; import java.util.regex.Pattern; -import sun.misc.Signal; -import sun.misc.SignalHandler; /** An example implementation of a worker process that is used for integration tests. */ public final class ExampleWorker { @@ -70,6 +68,7 @@ public final class ExampleWorker { // The options passed to this worker on a per-worker-lifetime basis. static ExampleWorkerOptions workerOptions; + private static WorkerMessageProcessor messageProcessor; private static class InterruptableWorkRequestHandler extends WorkRequestHandler { @@ -118,7 +117,7 @@ public static void main(String[] args) throws Exception { parser.parse(args); workerOptions = parser.getOptions(ExampleWorkerOptions.class); WorkerProtocolFormat protocolFormat = workerOptions.workerProtocol; - WorkRequestHandler.WorkerMessageProcessor messageProcessor = null; + messageProcessor = null; switch (protocolFormat) { case JSON: messageProcessor = @@ -147,21 +146,23 @@ private static int doWork(List args, PrintWriter err) { PrintStream originalStdOut = System.out; PrintStream originalStdErr = System.err; - if (workerOptions.waitForSignal) { - Semaphore signalSem = new Semaphore(0); - Signal.handle( - new Signal("HUP"), - new SignalHandler() { - @Override - public void handle(Signal sig) { - signalSem.release(); - } - }); + if (workerOptions.waitForCancel) { try { - signalSem.acquire(); - } catch (InterruptedException e) { - System.out.println("Interrupted while waiting for signal"); - e.printStackTrace(); + WorkRequest workRequest = messageProcessor.readWorkRequest(); + if (workRequest.getRequestId() != currentRequest.getRequestId()) { + System.err.format( + "Got cancel request for %d while expecting cancel request for %d%n", + workRequest.getRequestId(), currentRequest.getRequestId()); + return 1; + } + if (!workRequest.getCancel()) { + System.err.format( + "Got non-cancel request for %d while expecting cancel request%n", + workRequest.getRequestId()); + return 1; + } + } catch (IOException e) { + throw new RuntimeException("Exception while waiting for cancel request", e); } } try (PrintStream ps = new PrintStream(baos)) { diff --git a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorkerOptions.java b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorkerOptions.java index 440717916a3fd4..0c6310892c0d57 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorkerOptions.java +++ b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorkerOptions.java @@ -136,12 +136,12 @@ public static class ExampleWorkOptions extends OptionsBase { public boolean hardPoison; @Option( - name = "wait_for_signal", + name = "wait_for_cancel", documentationCategory = OptionDocumentationCategory.UNCATEGORIZED, effectTags = {OptionEffectTag.NO_OP}, defaultValue = "false", - help = "Don't send a response until receiving a SIGXXXX.") - public boolean waitForSignal; + help = "Don't send a response until receiving a cancel request.") + public boolean waitForCancel; /** Enum converter for --worker_protocol. */ public static class WorkerProtocolEnumConverter diff --git a/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java b/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java index aa4da66b10f827..aa34f2c45c248c 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java +++ b/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java @@ -45,6 +45,7 @@ static WorkerKey createWorkerKey( /* workerFilesWithHashes= */ ImmutableSortedMap.of(), /* mustBeSandboxed= */ false, /* proxied= */ proxied, + /* cancellable= */ false, WorkerProtocolFormat.PROTO); } @@ -58,6 +59,7 @@ static WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat, FileSystem /* workerFilesWithHashes= */ ImmutableSortedMap.of(), /* mustBeSandboxed= */ true, /* proxied= */ true, + /* cancellable= */ false, protocolFormat); } diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerFactoryTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerFactoryTest.java index 0d69cf325afa55..fa373493818b4d 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/WorkerFactoryTest.java +++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerFactoryTest.java @@ -58,6 +58,7 @@ protected WorkerKey createWorkerKey(boolean mustBeSandboxed, boolean proxied, St /* workerFilesWithHashes= */ ImmutableSortedMap.of(), /* mustBeSandboxed= */ mustBeSandboxed, /* proxied= */ proxied, + /* cancellable= */ false, WorkerProtocolFormat.PROTO); } diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerKeyTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerKeyTest.java index e8d5aa6a0be816..8c7e2df4c66a29 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/WorkerKeyTest.java +++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerKeyTest.java @@ -43,6 +43,7 @@ private WorkerKey makeWorkerKey(boolean multiplex, boolean dynamic) { /* workerFilesWithHashes= */ ImmutableSortedMap.of(), /* isSpeculative= */ dynamic, /* proxied= */ multiplex, + /* cancellable=*/ false, WorkerProtocolFormat.PROTO); } @@ -90,6 +91,7 @@ public void testWorkerKeyEquality() { workerKey.getWorkerFilesWithHashes(), workerKey.isSpeculative(), workerKey.getProxied(), + workerKey.isCancellable(), workerKey.getProtocolFormat()); assertThat(workerKey).isEqualTo(workerKeyWithSameFields); } @@ -107,6 +109,7 @@ public void testWorkerKeyInequality_protocol() { workerKey.getWorkerFilesWithHashes(), workerKey.isSpeculative(), workerKey.getProxied(), + workerKey.isCancellable(), WorkerProtocolFormat.JSON); assertThat(workerKey).isNotEqualTo(workerKeyWithDifferentProtocol); } diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerManagerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerManagerTest.java index 8e9c027d0dcc5f..760146ca25d0f2 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerManagerTest.java +++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerManagerTest.java @@ -59,6 +59,7 @@ public void instanceCreationRemovalTest() throws Exception { ImmutableSortedMap.of(), false, false, + /* cancellable= */ false, WorkerProtocolFormat.PROTO); WorkerMultiplexer wm1 = WorkerMultiplexerManager.getInstance(workerKey1, logFile); @@ -77,6 +78,7 @@ public void instanceCreationRemovalTest() throws Exception { ImmutableSortedMap.of(), false, false, + /* cancellable= */ false, WorkerProtocolFormat.PROTO); WorkerMultiplexer wm2 = WorkerMultiplexerManager.getInstance(workerKey2, logFile); diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java index adf995c45424af..8991d27760758d 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java +++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java @@ -17,6 +17,7 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.devtools.build.lib.worker.TestUtils.createWorkerKey; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -25,6 +26,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.devtools.build.lib.actions.ExecException; +import com.google.devtools.build.lib.actions.ExecutionRequirements; import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat; import com.google.devtools.build.lib.actions.MetadataProvider; import com.google.devtools.build.lib.actions.ResourceManager; @@ -45,14 +47,17 @@ import com.google.devtools.build.lib.vfs.FileSystemUtils; import com.google.devtools.build.lib.vfs.Path; import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; +import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest; import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse; import java.io.IOException; +import java.util.concurrent.Semaphore; import org.apache.commons.pool2.PooledObject; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -112,7 +117,6 @@ public void testExecInWorker_happyPath() throws ExecException, InterruptedExcept new WorkerOptions()); WorkerKey key = createWorkerKey(fs, "mnem", false); Path logFile = fs.getPath("/worker.log"); - when(worker.getLogFile()).thenReturn(logFile); when(worker.getResponse(0)) .thenReturn(WorkResponse.newBuilder().setExitCode(0).setOutput("out").build()); WorkResponse response = @@ -134,12 +138,102 @@ public void testExecInWorker_happyPath() throws ExecException, InterruptedExcept verify(context, times(1)).report(ProgressStatus.EXECUTING, "worker"); } + @Test + public void testExecInWorker_finishesAsyncOnInterrupt() throws InterruptedException, IOException { + WorkerSpawnRunner runner = + new WorkerSpawnRunner( + new SandboxHelpers(false), + fs.getPath("/execRoot"), + createWorkerPool(), + /* multiplex */ false, + reporter, + localEnvProvider, + /* binTools */ null, + resourceManager, + /* runfilesTreeUpdater=*/ null, + new WorkerOptions()); + WorkerKey key = createWorkerKey(fs, "mnem", false); + Path logFile = fs.getPath("/worker.log"); + when(worker.getResponse(anyInt())) + .thenThrow(new InterruptedException()) + .thenReturn(WorkResponse.newBuilder().setRequestId(2).build()); + assertThrows( + InterruptedException.class, + () -> + runner.execInWorker( + spawn, + key, + context, + new SandboxInputs(ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of()), + SandboxOutputs.create(ImmutableSet.of(), ImmutableSet.of()), + ImmutableList.of(), + inputFileCache, + spawnMetrics)); + assertThat(logFile.exists()).isFalse(); + verify(context, times(1)).report(ProgressStatus.EXECUTING, "worker"); + verify(worker, times(1)).putRequest(WorkRequest.newBuilder().setRequestId(0).build()); + } + + @Test + public void testExecInWorker_sendsCancelMessageOnInterrupt() + throws ExecException, InterruptedException, IOException { + WorkerOptions workerOptions = new WorkerOptions(); + workerOptions.workerCancellation = true; + when(spawn.getExecutionInfo()) + .thenReturn(ImmutableMap.of(ExecutionRequirements.SUPPORTS_WORKER_CANCELLATION, "1")); + WorkerSpawnRunner runner = + new WorkerSpawnRunner( + new SandboxHelpers(false), + fs.getPath("/execRoot"), + createWorkerPool(), + /* multiplex */ false, + reporter, + localEnvProvider, + /* binTools */ null, + resourceManager, + /* runfilesTreeUpdater=*/ null, + workerOptions); + WorkerKey key = createWorkerKey(fs, "mnem", false); + Path logFile = fs.getPath("/worker.log"); + Semaphore secondResponseRequested = new Semaphore(0); + when(worker.getResponse(anyInt())) + .thenThrow(new InterruptedException()) + .thenAnswer( + invocation -> { + secondResponseRequested.release(); + return WorkResponse.newBuilder() + .setRequestId(invocation.getArgument(0)) + .setWasCancelled(true) + .build(); + }); + assertThrows( + InterruptedException.class, + () -> + runner.execInWorker( + spawn, + key, + context, + new SandboxInputs(ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of()), + SandboxOutputs.create(ImmutableSet.of(), ImmutableSet.of()), + ImmutableList.of(), + inputFileCache, + spawnMetrics)); + secondResponseRequested.acquire(); + assertThat(logFile.exists()).isFalse(); + verify(context, times(1)).report(ProgressStatus.EXECUTING, "worker"); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(WorkRequest.class); + verify(worker, times(2)).putRequest(argumentCaptor.capture()); + assertThat(argumentCaptor.getAllValues().get(0)) + .isEqualTo(WorkRequest.newBuilder().setRequestId(0).build()); + assertThat(argumentCaptor.getAllValues().get(1)) + .isEqualTo(WorkRequest.newBuilder().setRequestId(0).setCancel(true).build()); + } + @Test public void testExecInWorker_noMultiplexWithDynamic() throws ExecException, InterruptedException, IOException { WorkerOptions workerOptions = new WorkerOptions(); workerOptions.workerMultiplex = true; - when(context.speculating()).thenReturn(true); WorkerSpawnRunner runner = new WorkerSpawnRunner( new SandboxHelpers(false), @@ -155,7 +249,6 @@ public void testExecInWorker_noMultiplexWithDynamic() // This worker key just so happens to be multiplex and require sandboxing. WorkerKey key = createWorkerKey(WorkerProtocolFormat.JSON, fs); Path logFile = fs.getPath("/worker.log"); - when(worker.getLogFile()).thenReturn(logFile); when(worker.getResponse(0)) .thenReturn( WorkResponse.newBuilder().setExitCode(0).setRequestId(0).setOutput("out").build());