Skip to content

Commit

Permalink
Use worker threads for module extension eval
Browse files Browse the repository at this point in the history
Now all the pieces are in place. Could it really be this simple??

Fixes #22729

PiperOrigin-RevId: 669391274
Change-Id: Idc58d24558e6db1331596ba33320a6f30c066b56
  • Loading branch information
Wyverald committed Sep 3, 2024
1 parent 38d14a6 commit 07f60f6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/util",
"//src/main/java/com/google/devtools/build/lib/util:os",
"//src/main/java/com/google/devtools/build/lib/vfs",
"//src/main/java/com/google/devtools/build/skyframe",
"//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
"//src/main/java/net/starlark/java/eval",
"//src/main/java/net/starlark/java/spelling",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static com.google.common.base.StandardSystemProperty.OS_ARCH;
import static com.google.common.collect.ImmutableSet.toImmutableSet;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.analysis.BlazeDirectories;
Expand Down Expand Up @@ -47,11 +48,14 @@
import com.google.devtools.build.lib.util.OS;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.skyframe.SkyFunction.Environment;
import com.google.devtools.build.skyframe.WorkerSkyKeyComputeState;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import net.starlark.java.eval.EvalException;
Expand Down Expand Up @@ -218,6 +222,37 @@ public RunModuleExtensionResult run(
ModuleExtensionId extensionId,
RepositoryMapping mainRepositoryMapping)
throws InterruptedException, ExternalDepsException {
// See below (the `catch CancellationException` clause) for why there's a `while` loop here.
while (true) {
var state = env.getState(WorkerSkyKeyComputeState<RunModuleExtensionResult>::new);
try {
return state.startOrContinueWork(
env,
"module-extension-" + extensionId.asTargetString(),
(workerEnv) ->
runInternal(
workerEnv, usagesValue, starlarkSemantics, extensionId, mainRepositoryMapping));
} catch (ExecutionException e) {
Throwables.throwIfInstanceOf(e.getCause(), ExternalDepsException.class);
Throwables.throwIfUnchecked(e.getCause());
throw new IllegalStateException(
"unexpected exception type: " + e.getCause().getClass(), e.getCause());
} catch (CancellationException e) {
// This can only happen if the state object was invalidated due to memory pressure, in
// which case we can simply reattempt eval.
}
}
}

@Nullable
private RunModuleExtensionResult runInternal(
Environment env,
SingleExtensionUsagesValue usagesValue,
StarlarkSemantics starlarkSemantics,
ModuleExtensionId extensionId,
RepositoryMapping mainRepositoryMapping)
throws InterruptedException, ExternalDepsException {
env.getListener().post(ModuleExtensionEvaluationProgress.ongoing(extensionId, "starting"));
char separator =
starlarkSemantics.getBool(BuildLanguageOptions.INCOMPATIBLE_USE_PLUS_IN_REPO_NAMES)
? '+'
Expand Down Expand Up @@ -275,6 +310,7 @@ public RunModuleExtensionResult run(
return null;
}
moduleContext.markSuccessful();
env.getListener().post(ModuleExtensionEvaluationProgress.finished(extensionId));
return new RunModuleExtensionResult(
moduleContext.getRecordedFileInputs(),
moduleContext.getRecordedDirentsInputs(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ public SkyValue compute(SkyKey skyKey, Environment env)
}

// Run that extension!
env.getListener().post(ModuleExtensionEvaluationProgress.ongoing(extensionId, "starting"));
RunModuleExtensionResult moduleExtensionResult;
try {
moduleExtensionResult =
Expand All @@ -185,7 +184,6 @@ public SkyValue compute(SkyKey skyKey, Environment env)
if (moduleExtensionResult == null) {
return null;
}
env.getListener().post(ModuleExtensionEvaluationProgress.finished(extensionId));
ImmutableMap<String, RepoSpec> generatedRepoSpecs = moduleExtensionResult.generatedRepoSpecs();
Optional<ModuleExtensionMetadata> moduleExtensionMetadata =
moduleExtensionResult.moduleExtensionMetadata();
Expand Down
33 changes: 33 additions & 0 deletions src/test/py/bazel/bzlmod/bazel_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,39 @@ def testDownload(self):
])
self.RunBazel(['build', '@no_op//:no_op'])

def testNoRestart(self):
self.ScratchFile(
'MODULE.bazel',
[
'data_ext = use_extension("//:ext.bzl", "data_ext")',
'use_repo(data_ext, "no_op")',
],
)
self.ScratchFile('BUILD')
self.ScratchFile('foo.txt', ['abc'])
self.ScratchFile(
'ext.bzl',
[
'def _no_op_impl(ctx):',
' ctx.file("BUILD", "filegroup(name=\\"no_op\\")")',
'no_op = repository_rule(_no_op_impl)',
'def _data_ext_impl(ctx):',
' print("I AM HERE")',
' ctx.watch(Label("//:foo.txt"))',
' no_op(name="no_op")',
'data_ext = module_extension(_data_ext_impl)',
],
)
_, _, stderr = self.RunBazel(['build', '@no_op//:no_op'])
stderr = '\n'.join(stderr)
self.assertIn('I AM HERE', stderr)
self.assertEqual(
stderr.count('I AM HERE'),
1,
'expected "I AM HERE" once, but got %s times:\n%s'
% (stderr.count('I AM HERE'), stderr),
)

def setUpProjectWithLocalRegistryModule(self, dep_name, dep_version):
self.main_registry.generateCcSource(dep_name, dep_version)
self.main_registry.createLocalPathModule(dep_name, dep_version,
Expand Down

0 comments on commit 07f60f6

Please sign in to comment.