Skip to content

Commit

Permalink
Implement proto_common.experimental_should_generate_code.
Browse files Browse the repository at this point in the history
  • Loading branch information
comius authored and copybara-github committed Apr 7, 2022
1 parent 4e2b21b commit cf7ebef
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@
import javax.annotation.Nullable;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Module;
import net.starlark.java.eval.Sequence;
import net.starlark.java.eval.Starlark;
import net.starlark.java.eval.StarlarkCallable;
import net.starlark.java.eval.StarlarkFunction;
import net.starlark.java.eval.StarlarkList;
import net.starlark.java.eval.StarlarkThread;
import net.starlark.java.eval.Tuple;

/** Utility functions for proto_library and proto aspect implementations. */
public class ProtoCommon {
Expand Down Expand Up @@ -206,4 +208,51 @@ public static void compile(
/* plugin_output */ pluginOutput == null ? Starlark.NONE : pluginOutput),
ImmutableMap.of("experimental_progress_message", progressMessage));
}

public static boolean shouldGenerateCode(
RuleContext ruleContext,
ConfiguredTarget protoTarget,
ProtoLangToolchainProvider protoLangToolchainInfo,
String ruleName)
throws RuleErrorException, InterruptedException {
StarlarkFunction shouldGenerateCode =
(StarlarkFunction)
ruleContext.getStarlarkDefinedBuiltin("proto_common_experimental_should_generate_code");
ruleContext.initStarlarkRuleContext();
return (Boolean)
ruleContext.callStarlarkOrThrowRuleError(
shouldGenerateCode,
ImmutableList.of(
/* proto_library_target */ protoTarget,
/* proto_lang_toolchain_info */ protoLangToolchainInfo,
/* rule_name */ ruleName),
ImmutableMap.of());
}

public static Sequence<Artifact> filterSources(
RuleContext ruleContext,
ConfiguredTarget protoTarget,
ProtoLangToolchainProvider protoLangToolchainInfo)
throws RuleErrorException, InterruptedException {
StarlarkFunction filterSources =
(StarlarkFunction)
ruleContext.getStarlarkDefinedBuiltin("proto_common_experimental_filter_sources");
ruleContext.initStarlarkRuleContext();
try {
return Sequence.cast(
((Tuple)
ruleContext.callStarlarkOrThrowRuleError(
filterSources,
ImmutableList.of(
/* proto_library_target */ protoTarget,
/* proto_lang_toolchain_info */ protoLangToolchainInfo),
ImmutableMap.of()))
.get(0),
Artifact.class,
"included");
} catch (EvalException e) {

throw new RuleErrorException(e.getMessageWithStack());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ public ImmutableList<Artifact> getDirectProtoSources() {
return directProtoSources;
}

@Override
public ImmutableList<ProtoSource> getDirectProtoSourcesForStarlark(StarlarkThread thread)
throws EvalException {
ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
return directSources;
}

/**
* The source root of the current library.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ public Provider getProvider() {
* Returns a list of {@link ProtoSource}s that are already provided by the protobuf runtime (i.e.
* for which {@code <lang>_proto_library} should not generate bindings.
*/
@StarlarkMethod(
name = "provided_proto_sources",
doc = "Proto sources provided by the toolchain.",
structField = true)
public abstract ImmutableList<ProtoSource> providedProtoSources();

@StarlarkMethod(name = "proto_compiler", doc = "Proto compiler.", structField = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ public Artifact getSourceFileForStarlark(StarlarkThread thread) throws EvalExcep
}

/** Returns the original source file. Only for forbidding protos! */
@Deprecated
@StarlarkMethod(name = "original_source_file", documented = false, useStarlarkThread = true)
public Artifact getOriginalSourceFileForStarlark(StarlarkThread thread) throws EvalException {
ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
return originalSourceFile;
}

Artifact getOriginalSourceFile() {
return originalSourceFile;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ interface ProtoInfoProviderApi extends ProviderApi {
structField = true)
ImmutableList<FileT> getDirectProtoSources();

@StarlarkMethod(name = "direct_proto_sources", documented = false, useStarlarkThread = true)
ImmutableList<?> getDirectProtoSourcesForStarlark(StarlarkThread thread) throws EvalException;

@StarlarkMethod(
name = "check_deps_sources",
doc =
Expand Down
2 changes: 2 additions & 0 deletions src/main/starlark/builtins_bzl/common/exports.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ exported_rules = {
exported_to_java = {
"register_compile_and_archive_actions_for_j2objc": compilation_support.register_compile_and_archive_actions_for_j2objc,
"proto_common_compile": proto_common_do_not_use.compile,
"proto_common_experimental_should_generate_code": proto_common_do_not_use.experimental_should_generate_code,
"proto_common_experimental_filter_sources": proto_common_do_not_use.experimental_filter_sources,
"link_multi_arch_static_library": linking_support.link_multi_arch_static_library,
}
73 changes: 73 additions & 0 deletions src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,84 @@ def _compile(
resource_set = resource_set,
)

_BAZEL_TOOLS_PREFIX = "external/bazel_tools/"

def _experimental_filter_sources(proto_library_target, proto_lang_toolchain_info):
proto_info = proto_library_target[_builtins.toplevel.ProtoInfo]
if not proto_info.direct_sources:
return [], []

# Collect a set of provided protos
provided_proto_sources = proto_lang_toolchain_info.provided_proto_sources
provided_paths = {}
for src in provided_proto_sources:
path = src.original_source_file().path

# For listed protos bundled with the Bazel tools repository, their exec paths start
# with external/bazel_tools/. This prefix needs to be removed first, because the protos in
# user repositories will not have that prefix.
if path.startswith(_BAZEL_TOOLS_PREFIX):
provided_paths[path[len(_BAZEL_TOOLS_PREFIX):]] = None
else:
provided_paths[path] = None

# Filter proto files
proto_files = [src.original_source_file() for src in proto_info.direct_proto_sources()]
excluded = []
included = []
for proto_file in proto_files:
if proto_file.path in provided_paths:
excluded.append(proto_file)
else:
included.append(proto_file)
return included, excluded

def _experimental_should_generate_code(
proto_library_target,
proto_lang_toolchain_info,
rule_name):
"""Checks if the code should be generated for the given proto_library.
The code shouldn't be generated only when the toolchain already provides it
to the language through its runtime dependency.
It fails when the proto_library contains mixed proto files, that should and
shouldn't generate code.
Args:
proto_library_target:
(Target) The proto_library to generate the sources for.
Obtained as the `target` parameter from an aspect's implementation.
proto_lang_toolchain_info:
(ProtoLangToolchainInfo) The proto lang toolchain info.
Obtained from a `proto_lang_toolchain` target or constructed ad-hoc.
rule_name: (str) Name of the rule used in the failure message.
Returns:
(bool) True when the code should be generated.
"""
included, excluded = _experimental_filter_sources(proto_library_target, proto_lang_toolchain_info)

if included and excluded:
fail(("The 'srcs' attribute of '%s' contains protos for which '%s' " +
"shouldn't generate code (%s), in addition to protos for which it should (%s).\n" +
"Separate '%s' into 2 proto_library rules.") % (
proto_library_target.label,
rule_name,
", ".join([f.short_path for f in excluded]),
", ".join([f.short_path for f in included]),
proto_library_target.label,
))

return bool(included)

proto_common = struct(
create_proto_compile_action = _create_proto_compile_action,
)

proto_common_do_not_use = struct(
compile = _compile,
experimental_should_generate_code = _experimental_should_generate_code,
experimental_filter_sources = _experimental_filter_sources,
ProtoLangToolchainInfo = _builtins.internal.ProtoLangToolchainInfo,
)
2 changes: 2 additions & 0 deletions src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ java_test(
"//src/main/java/com/google/devtools/build/lib/actions:localhost_capacity",
"//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
"//src/main/java/com/google/devtools/build/lib/analysis:configured_target",
"//src/main/java/com/google/devtools/build/lib/cmdline",
"//src/main/java/com/google/devtools/build/lib/packages",
"//src/main/java/com/google/devtools/build/lib/util:os",
"//src/test/java/com/google/devtools/build/lib/actions/util",
"//src/test/java/com/google/devtools/build/lib/analysis/util",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
import com.google.devtools.build.lib.analysis.ConfiguredTarget;
import com.google.devtools.build.lib.analysis.actions.SpawnAction;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.packages.StarlarkInfo;
import com.google.devtools.build.lib.packages.StarlarkProvider;
import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
import com.google.devtools.build.lib.packages.util.MockProtoSupport;
import com.google.devtools.build.lib.testutil.TestConstants;
import com.google.devtools.build.lib.util.OS;
Expand All @@ -39,6 +43,11 @@ public class BazelProtoCommonTest extends BuildViewTestCase {
private static final Correspondence<String, String> MATCHES_REGEX =
Correspondence.from((a, b) -> Pattern.matches(b, a), "matches");

private static final StarlarkProviderIdentifier boolProviderId =
StarlarkProviderIdentifier.forKey(
new StarlarkProvider.Key(
Label.parseAbsoluteUnchecked("//foo:should_generate.bzl"), "BoolProvider"));

@Before
public final void setup() throws Exception {
MockProtoSupport.setupWorkspace(scratch);
Expand All @@ -53,6 +62,8 @@ public final void setup() throws Exception {
"cc_library(name = 'runtime', srcs = ['runtime.cc'])",
"filegroup(name = 'descriptors', srcs = ['metadata.proto', 'descriptor.proto'])",
"filegroup(name = 'any', srcs = ['any.proto'])",
"filegroup(name = 'something', srcs = ['something.proto'])",
"proto_library(name = 'mixed', srcs = [':descriptors', ':something'])",
"proto_library(name = 'denied', srcs = [':descriptors', ':any'])");
scratch.file(
"foo/BUILD",
Expand Down Expand Up @@ -115,6 +126,21 @@ public final void setup() throws Exception {
" 'use_resource_set': attr.bool(),",
" 'progress_message': attr.string(),",
" })");

scratch.file(
"foo/should_generate.bzl",
"BoolProvider = provider()",
"def _impl(ctx):",
" result = proto_common_do_not_use.experimental_should_generate_code(",
" ctx.attr.proto_dep,",
" ctx.attr.toolchain[proto_common_do_not_use.ProtoLangToolchainInfo],",
" 'MyRule')",
" return [BoolProvider(value = result)]",
"should_generate_rule = rule(_impl,",
" attrs = {",
" 'proto_dep': attr.label(),",
" 'toolchain': attr.label(default = '//foo:toolchain'),",
" })");
}

/** Verifies basic usage of <code>proto_common.generate_code</code>. */
Expand Down Expand Up @@ -489,4 +515,55 @@ public void generateCode_overrideProgressMessage() throws Exception {
assertThat(spawnAction.getMnemonic()).isEqualTo("MyMnemonic");
assertThat(spawnAction.getProgressMessage()).isEqualTo("My //bar:simple");
}

/** Verifies <code>proto_common.should_generate_code</code> call. */
@Test
public void shouldGenerateCode_basic() throws Exception {
scratch.file(
"bar/BUILD",
TestConstants.LOAD_PROTO_LIBRARY,
"load('//foo:should_generate.bzl', 'should_generate_rule')",
"proto_library(name = 'proto', srcs = ['A.proto'])",
"should_generate_rule(name = 'simple', proto_dep = ':proto')");

ConfiguredTarget target = getConfiguredTarget("//bar:simple");

StarlarkInfo boolProvider = (StarlarkInfo) target.get(boolProviderId);
assertThat(boolProvider.getValue("value", Boolean.class)).isTrue();
}

/** Verifies <code>proto_common.should_generate_code</code> call. */
@Test
public void shouldGenerateCode_dontGenerate() throws Exception {
scratch.file(
"bar/BUILD",
TestConstants.LOAD_PROTO_LIBRARY,
"load('//foo:should_generate.bzl', 'should_generate_rule')",
"should_generate_rule(name = 'simple', proto_dep = '//third_party/x:denied')");

ConfiguredTarget target = getConfiguredTarget("//bar:simple");

StarlarkInfo boolProvider = (StarlarkInfo) target.get(boolProviderId);
assertThat(boolProvider.getValue("value", Boolean.class)).isFalse();
}

/** Verifies <code>proto_common.should_generate_code</code> call. */
@Test
public void shouldGenerateCode_mixed() throws Exception {
scratch.file(
"bar/BUILD",
TestConstants.LOAD_PROTO_LIBRARY,
"load('//foo:should_generate.bzl', 'should_generate_rule')",
"should_generate_rule(name = 'simple', proto_dep = '//third_party/x:mixed')");

reporter.removeHandler(failFastHandler);
getConfiguredTarget("//bar:simple");

assertContainsEvent(
"The 'srcs' attribute of '//third_party/x:mixed' contains protos for which 'MyRule'"
+ " shouldn't generate code (third_party/x/metadata.proto,"
+ " third_party/x/descriptor.proto), in addition to protos for which it should"
+ " (third_party/x/something.proto).\n"
+ "Separate '//third_party/x:mixed' into 2 proto_library rules.");
}
}

0 comments on commit cf7ebef

Please sign in to comment.