Skip to content

Commit

Permalink
Expose ProtoLangToolchainInfo to builtins.
Browse files Browse the repository at this point in the history
This necessitated changing ProtoLangToolchainProvider from native to Starlark type.

PiperOrigin-RevId: 437751442
  • Loading branch information
comius authored and copybara-github committed Mar 28, 2022
1 parent b4587f8 commit 528d067
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@
import static com.google.devtools.build.lib.packages.BuildType.LABEL_LIST;
import static com.google.devtools.build.lib.rules.java.proto.JavaLiteProtoAspect.getProtoToolchainLabel;

import com.google.common.collect.ImmutableList;
import com.google.devtools.build.lib.analysis.BaseRuleClasses;
import com.google.devtools.build.lib.analysis.RuleDefinition;
import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
import com.google.devtools.build.lib.analysis.TransitiveInfoProvider;
import com.google.devtools.build.lib.packages.RuleClass;
import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
import com.google.devtools.build.lib.rules.java.JavaConfiguration;
Expand Down Expand Up @@ -58,9 +56,7 @@ public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment envi
.aspect(javaProtoAspect))
.add(
attr(JavaProtoAspectCommon.LITE_PROTO_TOOLCHAIN_ATTR, LABEL)
.mandatoryBuiltinProviders(
ImmutableList.<Class<? extends TransitiveInfoProvider>>of(
ProtoLangToolchainProvider.class))
.mandatoryProviders(ProtoLangToolchainProvider.PROVIDER.id())
.value(getProtoToolchainLabel(DEFAULT_PROTO_TOOLCHAIN_LABEL)))
.advertiseStarlarkProvider(StarlarkProviderIdentifier.forKey(JavaInfo.PROVIDER.getKey()))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@ public AspectDefinition getDefinition(AspectParameters params) {
// For android_sdk rules, where we just want to get at aidl runtime deps.
.requireStarlarkProviders(forKey(AndroidSdkProvider.PROVIDER.getKey()))
.requireStarlarkProviders(forKey(ProtoInfo.PROVIDER.getKey()))
.requireProviderSets(
.requireStarlarkProviderSets(
ImmutableList.of(
// For proto_lang_toolchain rules, where we just want to get at their runtime
// deps.
ImmutableSet.of(ProtoLangToolchainProvider.class)))
ImmutableSet.of(ProtoLangToolchainProvider.PROVIDER.id())))
.addToolchainTypes(
ToolchainTypeRequirement.create(
Label.parseAbsoluteUnchecked(toolsRepository + sdkToolchainLabel)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public AspectDefinition getDefinition(AspectParameters aspectParameters) {
.useToolchainTransition(true)
.add(
attr(PROTO_TOOLCHAIN_ATTR, LABEL)
.mandatoryBuiltinProviders(ImmutableList.of(ProtoLangToolchainProvider.class))
.mandatoryProviders(ProtoLangToolchainProvider.PROVIDER.id())
.value(PROTO_TOOLCHAIN_LABEL))
.add(
attr(CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME, LABEL)
Expand Down Expand Up @@ -470,7 +470,7 @@ private void createProtoCompileAction(Collection<Artifact> outputs)
}

private ProtoLangToolchainProvider getProtoToolchainProvider() {
return ruleContext.getPrerequisite(PROTO_TOOLCHAIN_ATTR, ProtoLangToolchainProvider.class);
return ruleContext.getPrerequisite(PROTO_TOOLCHAIN_ATTR, ProtoLangToolchainProvider.PROVIDER);
}

public void addProviders(ConfiguredAspect.Builder builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.google.devtools.build.lib.analysis.PlatformConfiguration;
import com.google.devtools.build.lib.analysis.RuleContext;
import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
import com.google.devtools.build.lib.analysis.TransitiveInfoProvider;
import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.cmdline.RepositoryName;
Expand Down Expand Up @@ -117,9 +116,7 @@ public AspectDefinition getDefinition(AspectParameters aspectParameters) {
ImmutableList.of(StarlarkProviderIdentifier.forKey(JavaInfo.PROVIDER.getKey())))
.add(
attr(JavaProtoAspectCommon.LITE_PROTO_TOOLCHAIN_ATTR, LABEL)
.mandatoryBuiltinProviders(
ImmutableList.<Class<? extends TransitiveInfoProvider>>of(
ProtoLangToolchainProvider.class))
.mandatoryProviders(ProtoLangToolchainProvider.PROVIDER.id())
.value(getProtoToolchainLabel(defaultProtoToolchainLabel)))
.add(
attr(JavaRuleClasses.JAVA_TOOLCHAIN_ATTRIBUTE_NAME, LABEL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,15 @@ public ImmutableList<TransitiveInfoCollection> getProtoRuntimeDeps() {
/** Returns the toolchain that specifies how to generate code from {@code .proto} files. */
public ProtoLangToolchainProvider getProtoToolchainProvider() {
return checkNotNull(
ruleContext.getPrerequisite(protoToolchainAttr, ProtoLangToolchainProvider.class));
ruleContext.getPrerequisite(protoToolchainAttr, ProtoLangToolchainProvider.PROVIDER));
}

/**
* Returns the toolchain that specifies how to generate Java-lite code from {@code .proto} files.
*/
static ProtoLangToolchainProvider getLiteProtoToolchainProvider(RuleContext ruleContext) {
return ruleContext.getPrerequisite(LITE_PROTO_TOOLCHAIN_ATTR, ProtoLangToolchainProvider.class);
return ruleContext.getPrerequisite(
LITE_PROTO_TOOLCHAIN_ATTR, ProtoLangToolchainProvider.PROVIDER);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,6 @@ private static ProtoLangToolchainProvider getProtoToolchainProvider(
StarlarkRuleContext starlarkRuleContext, String protoToolchainAttr) throws EvalException {
ConfiguredTarget javaliteToolchain =
(ConfiguredTarget) checkNotNull(starlarkRuleContext.getAttr().getValue(protoToolchainAttr));
return checkNotNull(javaliteToolchain.getProvider(ProtoLangToolchainProvider.class));
return checkNotNull(javaliteToolchain.get(ProtoLangToolchainProvider.PROVIDER));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ private ConfiguredAspect proto(ConfiguredTarget base, RuleContext ruleContext)
ImmutableList<Artifact> protoSources = protoInfo.getDirectProtoSources();

ProtoLangToolchainProvider protoToolchain =
ruleContext.getPrerequisite(J2OBJC_PROTO_TOOLCHAIN_ATTR, ProtoLangToolchainProvider.class);
ruleContext.getPrerequisite(
J2OBJC_PROTO_TOOLCHAIN_ATTR, ProtoLangToolchainProvider.PROVIDER);
// Avoid pulling in any generated files from forbidden protos.
ProtoSourceFileExcludeList protoExcludeList =
new ProtoSourceFileExcludeList(ruleContext, protoToolchain.forbiddenProtos());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public ConfiguredTarget create(RuleContext ruleContext)
flag = flag.replace("$(OUT)", "%s");

return new RuleConfiguredTargetBuilder(ruleContext)
.addProvider(
.addStarlarkDeclaredProvider(
ProtoLangToolchainProvider.create(
flag,
ruleContext.attributes().get("plugin_format_flag", Type.STRING),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import com.google.devtools.build.lib.actions.Artifact;
import com.google.devtools.build.lib.analysis.FilesToRunProvider;
import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
import com.google.devtools.build.lib.analysis.TransitiveInfoProvider;
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
import com.google.devtools.build.lib.packages.BuiltinProvider;
import com.google.devtools.build.lib.packages.NativeInfo;
import javax.annotation.Nullable;
import net.starlark.java.annot.StarlarkBuiltin;
import net.starlark.java.annot.StarlarkMethod;
import net.starlark.java.eval.StarlarkList;

Expand All @@ -33,7 +35,23 @@
* rules.
*/
@AutoValue
public abstract class ProtoLangToolchainProvider implements TransitiveInfoProvider {
public abstract class ProtoLangToolchainProvider extends NativeInfo {
public static final String PROVIDER_NAME = "ProtoLangToolchainInfo";
public static final Provider PROVIDER = new Provider();

/** Provider class for {@link ProtoLangToolchainProvider} objects. */
@StarlarkBuiltin(name = "Provider", documented = false, doc = "")
public static class Provider extends BuiltinProvider<ProtoLangToolchainProvider> {
public Provider() {
super(PROVIDER_NAME, ProtoLangToolchainProvider.class);
}
}

@Override
public Provider getProvider() {
return PROVIDER;
}

@StarlarkMethod(
name = "out_replacement_format_flag",
doc = "Format string used when passing output to the plugin used by proto compiler.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment envi
.exec()
.value(PROTO_COMPILER))
.requiresConfigurationFragments(ProtoConfiguration.class)
.advertiseProvider(ProtoLangToolchainProvider.class)
.advertiseStarlarkProvider(ProtoLangToolchainProvider.PROVIDER.id())
.removeAttribute("data")
.removeAttribute("deps")
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public void protoToolchain() throws Exception {
update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());

validateProtoLangToolchain(
getConfiguredTarget("//foo:toolchain").getProvider(ProtoLangToolchainProvider.class));
getConfiguredTarget("//foo:toolchain").get(ProtoLangToolchainProvider.PROVIDER));
}

@Test
Expand Down Expand Up @@ -125,7 +125,7 @@ public void protoToolchainBlacklistProtoLibraries() throws Exception {
update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());

validateProtoLangToolchain(
getConfiguredTarget("//foo:toolchain").getProvider(ProtoLangToolchainProvider.class));
getConfiguredTarget("//foo:toolchain").get(ProtoLangToolchainProvider.PROVIDER));
}

@Test
Expand Down Expand Up @@ -156,7 +156,7 @@ public void protoToolchainBlacklistTransitiveProtos() throws Exception {
update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());

validateProtoLangToolchain(
getConfiguredTarget("//foo:toolchain").getProvider(ProtoLangToolchainProvider.class));
getConfiguredTarget("//foo:toolchain").get(ProtoLangToolchainProvider.PROVIDER));
}

@Test
Expand All @@ -172,7 +172,7 @@ public void optionalFieldsAreEmpty() throws Exception {
update(ImmutableList.of("//foo:toolchain"), false, 1, true, new EventBus());

ProtoLangToolchainProvider toolchain =
getConfiguredTarget("//foo:toolchain").getProvider(ProtoLangToolchainProvider.class);
getConfiguredTarget("//foo:toolchain").get(ProtoLangToolchainProvider.PROVIDER);

assertThat(toolchain.pluginExecutable()).isNull();
assertThat(toolchain.runtime()).isNull();
Expand Down

0 comments on commit 528d067

Please sign in to comment.