Skip to content

Commit

Permalink
Simplify prerequisites query functions in RuleContext
Browse files Browse the repository at this point in the history
- keep only one function to get split prerequisites that returns `ConfiguredTargetAndData` and modify users to get `ConfiguredTarget` from it.
- remove `getPrerequisiteMap()`

PiperOrigin-RevId: 573886458
Change-Id: Ic4dbf5add0ccebd39e4c00492c4fd51d283d8f09
  • Loading branch information
mai93 authored and copybara-github committed Oct 16, 2023
1 parent cadbaa5 commit 137d3f1
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
import com.google.devtools.build.lib.actions.ActionAnalysisMetadata;
Expand Down Expand Up @@ -795,44 +794,12 @@ public boolean isAttrDefined(String attrName, Type<?> type) {
return attributes().has(attrName, type);
}

/**
* Returns the dependencies through a {@code LABEL_DICT_UNARY} attribute as a map from a string to
* a {@link TransitiveInfoCollection}.
*/
public Map<String, TransitiveInfoCollection> getPrerequisiteMap(String attributeName) {
Preconditions.checkState(attributes().has(attributeName, BuildType.LABEL_DICT_UNARY));

ImmutableMap.Builder<String, TransitiveInfoCollection> result = ImmutableMap.builder();
Map<String, Label> dict = attributes().get(attributeName, BuildType.LABEL_DICT_UNARY);
Map<Label, ConfiguredTarget> labelToDep = new HashMap<>();
for (ConfiguredTargetAndData dep : targetMap.get(attributeName)) {
labelToDep.put(dep.getTargetLabel(), dep.getConfiguredTarget());
}

for (Map.Entry<String, Label> entry : dict.entrySet()) {
result.put(entry.getKey(), Preconditions.checkNotNull(labelToDep.get(entry.getValue())));
}

return result.buildOrThrow();
}

/**
* Returns the prerequisites keyed by their configuration transition keys. If the split transition
* is not active (e.g. split() returned an empty list), the key is an empty Optional.
*/
public Map<Optional<String>, ? extends List<? extends TransitiveInfoCollection>>
getSplitPrerequisites(String attributeName) {
return Maps.transformValues(
getSplitPrerequisiteConfiguredTargetAndTargets(attributeName),
(ctatList) -> Lists.transform(ctatList, ConfiguredTargetAndData::getConfiguredTarget));
}

/**
* Returns the prerequisites keyed by their transition keys. If the split transition is not active
* (e.g. split() returned an empty list), the key is an empty Optional.
*/
public Map<Optional<String>, List<ConfiguredTargetAndData>>
getSplitPrerequisiteConfiguredTargetAndTargets(String attributeName) {
public Map<Optional<String>, List<ConfiguredTargetAndData>> getSplitPrerequisites(
String attributeName) {
checkAttributeIsDependency(attributeName);
// Use an ImmutableListMultimap.Builder here to preserve ordering.
ImmutableListMultimap.Builder<Optional<String>, ConfiguredTargetAndData> result =
Expand Down Expand Up @@ -939,7 +906,7 @@ && attributes().getAttributeDefinition(attributeName).getTransitionFactory().isS
// portion of the split transition.
// Callers should be identified, cleaned up, and this check removed.
Map<Optional<String>, List<ConfiguredTargetAndData>> map =
getSplitPrerequisiteConfiguredTargetAndTargets(attributeName);
getSplitPrerequisites(attributeName);
prerequisiteConfiguredTargets =
map.isEmpty() ? ImmutableList.of() : map.entrySet().iterator().next().getValue();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.devtools.build.lib.analysis.starlark;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.devtools.build.lib.analysis.config.transitions.ConfigurationTransition.PATCH_TRANSITION_KEY;
import static com.google.devtools.build.lib.analysis.starlark.StarlarkRuleClassFunctions.ALLOWLIST_EXTEND_RULE;
Expand Down Expand Up @@ -78,6 +79,7 @@
import com.google.devtools.build.lib.packages.Type.LabelClass;
import com.google.devtools.build.lib.shell.ShellUtils;
import com.google.devtools.build.lib.shell.ShellUtils.TokenizationException;
import com.google.devtools.build.lib.skyframe.ConfiguredTargetAndData;
import com.google.devtools.build.lib.starlarkbuildapi.StarlarkRuleContextApi;
import com.google.devtools.build.lib.starlarkbuildapi.StarlarkSubruleApi;
import com.google.devtools.build.lib.starlarkbuildapi.platform.ToolchainContextApi;
Expand Down Expand Up @@ -501,12 +503,12 @@ private static StructImpl buildSplitAttributeInfo(
if (!attr.getTransitionFactory().isSplit()) {
continue;
}
Map<Optional<String>, ? extends List<? extends TransitiveInfoCollection>> splitPrereqs =
Map<Optional<String>, List<ConfiguredTargetAndData>> splitPrereqs =
ruleContext.getSplitPrerequisites(attr.getName());

Map<Object, Object> splitPrereqsMap = new LinkedHashMap<>();
for (Map.Entry<Optional<String>, ? extends List<? extends TransitiveInfoCollection>>
splitPrereq : splitPrereqs.entrySet()) {
for (Map.Entry<Optional<String>, List<ConfiguredTargetAndData>> splitPrereq :
splitPrereqs.entrySet()) {

// Skip a split with an empty dependency list.
// TODO(jungjw): Figure out exactly which cases trigger this and see if this can be made
Expand All @@ -518,10 +520,14 @@ private static StructImpl buildSplitAttributeInfo(
Object value;
if (attr.getType() == BuildType.LABEL) {
Preconditions.checkState(splitPrereq.getValue().size() == 1);
value = splitPrereq.getValue().get(0);
value = splitPrereq.getValue().get(0).getConfiguredTarget();
} else {
// BuildType.LABEL_LIST
value = StarlarkList.immutableCopyOf(splitPrereq.getValue());
value =
StarlarkList.immutableCopyOf(
splitPrereq.getValue().stream()
.map(ConfiguredTargetAndData::getConfiguredTarget)
.collect(toImmutableList()));
}

if (splitPrereq.getKey().isPresent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import com.google.auto.value.AutoValue;
import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
Expand All @@ -51,7 +50,6 @@
import com.google.devtools.build.lib.analysis.RuleContext;
import com.google.devtools.build.lib.analysis.Runfiles;
import com.google.devtools.build.lib.analysis.RunfilesProvider;
import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
import com.google.devtools.build.lib.analysis.actions.ActionConstructionContext;
import com.google.devtools.build.lib.analysis.actions.CustomCommandLine;
import com.google.devtools.build.lib.analysis.actions.CustomCommandLine.VectorArg;
Expand Down Expand Up @@ -88,6 +86,7 @@
import com.google.devtools.build.lib.rules.java.OneVersionCheckActionBuilder;
import com.google.devtools.build.lib.rules.java.ProguardSpecProvider;
import com.google.devtools.build.lib.server.FailureDetails.FailAction.Code;
import com.google.devtools.build.lib.skyframe.ConfiguredTargetAndData;
import com.google.devtools.build.lib.vfs.PathFragment;
import java.io.Serializable;
import java.util.ArrayList;
Expand Down Expand Up @@ -970,10 +969,10 @@ public static RuleConfiguredTargetBuilder createAndroidBinary(
attr -> !"deps".equals(attr),
validations -> builder.addOutputGroup(OutputGroupInfo.VALIDATION_TRANSITIVE, validations));
boolean filterSplitValidations = false; // propagate validations from first split unfiltered
for (List<? extends TransitiveInfoCollection> deps :
ruleContext.getSplitPrerequisites("deps").values()) {
for (List<ConfiguredTargetAndData> deps : ruleContext.getSplitPrerequisites("deps").values()) {
for (OutputGroupInfo provider :
AnalysisUtils.getProviders(deps, OutputGroupInfo.STARLARK_CONSTRUCTOR)) {
AnalysisUtils.getProviders(
getConfiguredTargets(deps), OutputGroupInfo.STARLARK_CONSTRUCTOR)) {
NestedSet<Artifact> validations = provider.getOutputGroup(OutputGroupInfo.VALIDATION);
if (filterSplitValidations) {
// Filter out Android Lint validations by name: we know these validations are expensive
Expand Down Expand Up @@ -1038,16 +1037,22 @@ public static NestedSet<Artifact> getTransitiveNativeLibs(RuleContext ruleContex
// libraries across multiple architectures, e.g. x86 and armeabi-v7a, and need to be packed
// into the APK.
NestedSetBuilder<Artifact> transitiveNativeLibs = NestedSetBuilder.naiveLinkOrder();
for (Map.Entry<Optional<String>, ? extends List<? extends TransitiveInfoCollection>> entry :
ruleContext.getSplitPrerequisites("deps").entrySet()) {
for (List<ConfiguredTargetAndData> deps : ruleContext.getSplitPrerequisites("deps").values()) {
for (AndroidNativeLibsInfo provider :
AnalysisUtils.getProviders(entry.getValue(), AndroidNativeLibsInfo.PROVIDER)) {
AnalysisUtils.getProviders(getConfiguredTargets(deps), AndroidNativeLibsInfo.PROVIDER)) {
transitiveNativeLibs.addTransitive(provider.getNativeLibs());
}
}
return transitiveNativeLibs.build();
}

private static ImmutableList<ConfiguredTarget> getConfiguredTargets(
List<ConfiguredTargetAndData> prerequisitesList) {
return prerequisitesList.stream()
.map(ConfiguredTargetAndData::getConfiguredTarget)
.collect(toImmutableList());
}

static class Java8LegacyDexOutput {
private final Artifact dex;
private final Artifact map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
// limitations under the License.
package com.google.devtools.build.lib.rules.cpp;

import static com.google.common.collect.ImmutableMap.toImmutableMap;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -43,6 +46,7 @@
* com.google.devtools.build.lib.rules.cpp.CppConfiguration}.
*/
public class CcToolchainSuite implements RuleConfiguredTargetFactory {
private static final String TOOLCHAIN_ATTRIBUTE_NAME = "toolchains";

private static TemplateVariableInfo createMakeVariableProvider(
CcToolchainProvider toolchainProvider, Location location) {
Expand All @@ -68,7 +72,7 @@ public ConfiguredTarget create(RuleContext ruleContext)
String compiler = cppConfiguration.getCompilerFromOptions();
String key = transformedCpu + (Strings.isNullOrEmpty(compiler) ? "" : ("|" + compiler));
Map<String, Label> toolchains =
ruleContext.attributes().get("toolchains", BuildType.LABEL_DICT_UNARY);
ruleContext.attributes().get(TOOLCHAIN_ATTRIBUTE_NAME, BuildType.LABEL_DICT_UNARY);
Label selectedCcToolchain = toolchains.get(key);
CcToolchainProvider ccToolchainProvider;

Expand Down Expand Up @@ -134,6 +138,28 @@ public ConfiguredTarget create(RuleContext ruleContext)
return builder.build();
}

/**
* Returns the toolchains defined through a {@code LABEL_DICT_UNARY} attribute as a map from a
* string to a {@link TransitiveInfoCollection}.
*/
private ImmutableMap<String, TransitiveInfoCollection> getToolchainsMap(RuleContext ruleContext) {
Preconditions.checkState(
ruleContext.attributes().has(TOOLCHAIN_ATTRIBUTE_NAME, BuildType.LABEL_DICT_UNARY));

ImmutableMap.Builder<String, TransitiveInfoCollection> result = ImmutableMap.builder();
Map<String, Label> dict =
ruleContext.attributes().get(TOOLCHAIN_ATTRIBUTE_NAME, BuildType.LABEL_DICT_UNARY);
ImmutableMap<Label, ConfiguredTarget> labelToDep =
ruleContext.getPrerequisiteConfiguredTargets(TOOLCHAIN_ATTRIBUTE_NAME).stream()
.collect(toImmutableMap(dep -> dep.getTargetLabel(), dep -> dep.getConfiguredTarget()));

for (Map.Entry<String, Label> entry : dict.entrySet()) {
result.put(entry.getKey(), Preconditions.checkNotNull(labelToDep.get(entry.getValue())));
}

return result.buildOrThrow();
}

private <T extends HasCcToolchainLabel> T selectCcToolchain(
BuiltinProvider<T> providerType,
RuleContext ruleContext,
Expand All @@ -142,7 +168,7 @@ private <T extends HasCcToolchainLabel> T selectCcToolchain(
Label selectedCcToolchain)
throws RuleErrorException {
T selectedAttributes = null;
for (TransitiveInfoCollection dep : ruleContext.getPrerequisiteMap("toolchains").values()) {
for (TransitiveInfoCollection dep : getToolchainsMap(ruleContext).values()) {
T attributes = dep.get(providerType);
if (attributes != null && attributes.getCcToolchainLabel().equals(selectedCcToolchain)) {
selectedAttributes = attributes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ public static AppleLinkingOutputs linkMultiArchBinary(
UserVariablesExtension userVariablesExtension)
throws InterruptedException, RuleErrorException, EvalException {
Map<Optional<String>, List<ConfiguredTargetAndData>> splitDeps =
ruleContext.getSplitPrerequisiteConfiguredTargetAndTargets("deps");
ruleContext.getSplitPrerequisites("deps");
Map<Optional<String>, List<ConfiguredTargetAndData>> splitToolchains =
ruleContext.getSplitPrerequisiteConfiguredTargetAndTargets(
ObjcRuleClasses.CHILD_CONFIG_ATTR);
ruleContext.getSplitPrerequisites(ObjcRuleClasses.CHILD_CONFIG_ATTR);

Preconditions.checkState(
splitDeps.keySet().isEmpty() || splitDeps.keySet().equals(splitToolchains.keySet()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,7 @@ public StructImpl linkMultiArchStaticLibrary(
ruleContext.getStarlarkDefinedBuiltin("link_multi_arch_static_library");
Dict<String, StructImpl> splitTargetTriplets =
MultiArchBinarySupport.getSplitTargetTripletFromCtads(
ruleContext.getSplitPrerequisiteConfiguredTargetAndTargets(
ObjcRuleClasses.CHILD_CONFIG_ATTR));
ruleContext.getSplitPrerequisites(ObjcRuleClasses.CHILD_CONFIG_ATTR));
return (StructImpl)
ruleContext.callStarlarkOrThrowRuleError(
linkMultiArchLibrary,
Expand Down

0 comments on commit 137d3f1

Please sign in to comment.