Skip to content

Commit

Permalink
Avoid using futures to collect DFs on workers
Browse files Browse the repository at this point in the history
To collect DFs for fault-tolerant execution, it is necessary
to update the final version of dynamic filters in DynamicFiltersCollector
before the task completes. This ensures that the coordinator is
aware of the need to fetch dynamic filters from a worker even after
the task has finished running.
  • Loading branch information
raunaqmorarka authored and arhimondr committed Jun 16, 2022
1 parent 7823e01 commit 8935a09
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,28 @@
package io.trino.sql.planner;

import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;

import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

Expand All @@ -51,71 +49,106 @@ public class LocalDynamicFilterConsumer
// Mapping from dynamic filter ID to its build channel type.
private final Map<DynamicFilterId, Type> filterBuildTypes;

private final SettableFuture<TupleDomain<DynamicFilterId>> resultFuture;
private final List<Consumer<Map<DynamicFilterId, Domain>>> collectors;

// Number of build-side partitions to be collected, must be provided by setPartitionCount
@GuardedBy("this")
private int expectedPartitionCount = PARTITION_COUNT_INITIAL_VALUE;

@GuardedBy("this")
private boolean collected;

// The resulting predicates from each build-side partition.
@Nullable
@GuardedBy("this")
private final List<TupleDomain<DynamicFilterId>> partitions;
private List<TupleDomain<DynamicFilterId>> partitions;

public LocalDynamicFilterConsumer(Map<DynamicFilterId, Integer> buildChannels, Map<DynamicFilterId, Type> filterBuildTypes)
/**
 * Creates a consumer that gathers per-partition dynamic filter domains and, once collection
 * is complete, hands the resulting domains to the supplied collectors.
 *
 * @param buildChannels map keyed by dynamic filter ID with an integer build channel as value; must not be null
 * @param filterBuildTypes map from dynamic filter ID to its build channel type; must not be null and
 *         must have exactly the same key set as {@code buildChannels}
 * @param collectors callbacks that receive the collected domains; must be neither null nor empty
 */
public LocalDynamicFilterConsumer(
Map<DynamicFilterId, Integer> buildChannels,
Map<DynamicFilterId, Type> filterBuildTypes,
List<Consumer<Map<DynamicFilterId, Domain>>> collectors)
{
this.buildChannels = requireNonNull(buildChannels, "buildChannels is null");
this.filterBuildTypes = requireNonNull(filterBuildTypes, "filterBuildTypes is null");
// Both maps are keyed by dynamic filter ID and must describe the same set of filters.
verify(buildChannels.keySet().equals(filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys");

this.resultFuture = SettableFuture.create();
requireNonNull(collectors, "collectors is null");
checkArgument(!collectors.isEmpty(), "collectors is empty");
// NOTE(review): the list is stored as passed in, without a defensive copy; this presumably
// relies on callers handing over an immutable list — confirm.
this.collectors = collectors;
this.partitions = new ArrayList<>();
}

/**
 * Returns a future derived from {@code resultFuture} by applying {@code convertTupleDomain}
 * to its value; the transformation is scheduled on the direct executor.
 */
public ListenableFuture<Map<DynamicFilterId, Domain>> getDynamicFilterDomains()
{
return Futures.transform(resultFuture, this::convertTupleDomain, directExecutor());
}

/**
 * Records the dynamic filter domain produced by one build-side partition.
 * <p>
 * Collection finishes when either an unconstrained domain ({@code isAll()}) arrives, in which case
 * that domain is the result, or the number of recorded partitions reaches the expected partition
 * count, in which case the result is the column-wise union of the recorded partitions. The result
 * is then passed to {@code notifyConsumers}. Calls made after collection has finished return
 * without doing anything.
 *
 * @param tupleDomain the domain collected for a single build-side partition
 */
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
{
if (resultFuture.isDone()) {
return;
}
TupleDomain<DynamicFilterId> result = null;
TupleDomain<DynamicFilterId> result;
synchronized (this) {
// Collection already finished; nothing more to record.
if (collected) {
return;
}
requireNonNull(partitions, "partitions is null");
// Called concurrently by each DynamicFilterSourceOperator instance (when collection is over).
verify(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE || partitions.size() < expectedPartitionCount);
// NOTE: may result in a bit more relaxed constraint if there are multiple columns and multiple rows.
// See the comment at TupleDomain::columnWiseUnion() for more details.
partitions.add(tupleDomain);
if (partitions.size() == expectedPartitionCount || tupleDomain.isAll()) {
// An unconstrained partition is used directly as the result, without waiting for the remaining partitions.
if (tupleDomain.isAll()) {
result = tupleDomain;
}
else if (partitions.size() == expectedPartitionCount) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
// NOTE(review): a partition was appended just above, so this emptiness check looks
// unreachable in this method — confirm.
if (partitions.isEmpty()) {
result = TupleDomain.none();
}
else {
result = TupleDomain.columnWiseUnion(partitions);
}
}
else {
// More partitions are still expected; keep collecting.
return;
}
// Mark collection as finished and drop the partition list; later calls exit via the collected check.
collected = true;
partitions = null;
}

if (result != null) {
resultFuture.set(result);
}
notifyConsumers(result);
}

/**
 * Sets the number of build-side partitions that are expected to be collected.
 * <p>
 * If the number of partitions recorded so far already equals {@code partitionCount}, collection
 * finishes here: the result is {@code TupleDomain.none()} when no partitions were recorded,
 * otherwise the column-wise union of the recorded partitions, and it is passed to
 * {@code notifyConsumers}. Otherwise the method returns after storing the count.
 * Calls made after collection has finished return without doing anything; apart from that,
 * the count may be set only once.
 *
 * @param partitionCount the expected number of build-side partitions
 */
@Override
public void setPartitionCount(int partitionCount)
{
TupleDomain<DynamicFilterId> result = null;
TupleDomain<DynamicFilterId> result;
synchronized (this) {
// Collection already finished; the count is no longer needed.
if (collected) {
return;
}
checkState(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE, "setPartitionCount should be called only once");
requireNonNull(partitions, "partitions is null");
expectedPartitionCount = partitionCount;
if (partitions.size() == expectedPartitionCount) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
// No partitions were recorded and none are expected: the result is TupleDomain.none().
if (partitions.isEmpty()) {
result = TupleDomain.none();
}
else {
result = TupleDomain.columnWiseUnion(partitions);
}
// Mark collection as finished and drop the partition list; later calls exit via the collected check.
collected = true;
partitions = null;
}
else {
// Some partitions are still outstanding; this method does not finish collection in that case.
return;
}
}

if (result != null) {
resultFuture.set(result);
}
notifyConsumers(result);
}

/**
 * Converts the final tuple domain into a per-filter domain map and passes that same map
 * to every registered collector.
 *
 * @param result the final collected domain; must not be null
 */
private void notifyConsumers(TupleDomain<DynamicFilterId> result)
{
    Map<DynamicFilterId, Domain> domains = convertTupleDomain(requireNonNull(result, "result is null"));
    for (Consumer<Map<DynamicFilterId, Domain>> collector : collectors) {
        collector.accept(domains);
    }
}

private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilterId> result)
Expand All @@ -135,7 +168,8 @@ private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilte
public static LocalDynamicFilterConsumer create(
JoinNode planNode,
List<Type> buildSourceTypes,
Set<DynamicFilterId> collectedFilters)
Set<DynamicFilterId> collectedFilters,
List<Consumer<Map<DynamicFilterId, Domain>>> collectors)
{
checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty.");
checkArgument(!collectedFilters.isEmpty(), "Collected dynamic filters set is empty");
Expand All @@ -159,7 +193,7 @@ public static LocalDynamicFilterConsumer create(
.collect(toImmutableMap(
Map.Entry::getKey,
entry -> buildSourceTypes.get(entry.getValue())));
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes);
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes, collectors);
}

public Map<DynamicFilterId, Integer> getBuildChannels()
Expand All @@ -168,11 +202,12 @@ public Map<DynamicFilterId, Integer> getBuildChannels()
}

/**
 * Returns a debug representation containing the build channels, the expected partition count,
 * the collected flag and the recorded partitions.
 * <p>
 * The synchronized variant holds this object's monitor because {@code expectedPartitionCount},
 * {@code collected} and {@code partitions} are annotated {@code @GuardedBy("this")}.
 */
@Override
public String toString()
public synchronized String toString()
{
return toStringHelper(this)
.add("buildChannels", buildChannels)
.add("expectedPartitionCount", expectedPartitionCount)
.add("collected", collected)
.add("partitions", partitions)
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.Session;
Expand Down Expand Up @@ -268,6 +267,7 @@
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -286,7 +286,6 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static com.google.common.collect.Sets.difference;
import static io.airlift.concurrent.MoreFutures.addSuccessCallback;
import static io.trino.SystemSessionProperties.getAdaptivePartialAggregationMinRows;
import static io.trino.SystemSessionProperties.getAdaptivePartialAggregationUniqueRowsRatioThreshold;
import static io.trino.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit;
Expand Down Expand Up @@ -745,9 +744,12 @@ private void registerCoordinatorDynamicFilters(List<DynamicFilters.Descriptor> d
difference(consumedFilterIds, dynamicFiltersCollector.getRegisteredDynamicFilterIds()));
}

private void addCoordinatorDynamicFilters(Map<DynamicFilterId, Domain> dynamicTupleDomain)
/**
 * Returns a consumer that forwards dynamic filter domains to {@code taskContext.updateDomains},
 * keeping only the entries whose filter ID is contained in {@code coordinatorDynamicFilters}.
 *
 * @param coordinatorDynamicFilters IDs of the dynamic filters whose domains should be forwarded
 */
private Consumer<Map<DynamicFilterId, Domain>> getCoordinatorDynamicFilterDomainsCollector(Set<DynamicFilterId> coordinatorDynamicFilters)
{
taskContext.updateDomains(dynamicTupleDomain);
// Entries for filter IDs outside coordinatorDynamicFilters are dropped before the update.
return domains -> taskContext.updateDomains(
domains.entrySet().stream()
.filter(entry -> coordinatorDynamicFilters.contains(entry.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)));
}

public Optional<IndexSourceContext> getIndexSourceContext()
Expand Down Expand Up @@ -2913,18 +2915,19 @@ private Optional<LocalDynamicFilterConsumer> createDynamicFilter(
buildSource.getPipelineExecutionStrategy() != GROUPED_EXECUTION,
"Dynamic filtering cannot be used with grouped execution");
log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters());
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), collectedDynamicFilters);
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
ImmutableList.Builder<Consumer<Map<DynamicFilterId, Domain>>> collectors = ImmutableList.builder();
if (!localDynamicFilters.isEmpty()) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
collectors.add(context::addLocalDynamicFilters);
}
if (!coordinatorDynamicFilters.isEmpty()) {
addSuccessCallback(
domainsFuture,
domains -> context.addCoordinatorDynamicFilters(domains.entrySet().stream()
.filter(entry -> coordinatorDynamicFilters.contains(entry.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))));
collectors.add(context.getCoordinatorDynamicFilterDomainsCollector(coordinatorDynamicFilters));
}
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(
node,
buildSource.getTypes(),
collectedDynamicFilters,
collectors.build());

return Optional.of(filterConsumer);
}

Expand Down Expand Up @@ -3077,17 +3080,18 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
// Add a DynamicFilterSourceOperatorFactory to build operator factories
DynamicFilterId filterId = node.getDynamicFilterId().get();
log.debug("[Semi-join] Dynamic filter: %s", filterId);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)));
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
ImmutableList.Builder<Consumer<Map<DynamicFilterId, Domain>>> collectors = ImmutableList.builder();
if (isLocalDynamicFilter) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
collectors.add(context::addLocalDynamicFilters);
}
if (isCoordinatorDynamicFilter) {
addSuccessCallback(domainsFuture, context::addCoordinatorDynamicFilters);
collectors.add(context.getCoordinatorDynamicFilterDomainsCollector(ImmutableSet.of(filterId)));
}
boolean isReplicatedJoin = isBuildSideReplicated(node);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)),
collectors.build());
buildSource = new PhysicalOperation(
new DynamicFilterSourceOperatorFactory(
operatorId,
Expand Down
Loading

0 comments on commit 8935a09

Please sign in to comment.