Skip to content

Commit

Permalink
Cache generated AccumulatorFactory classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Mar 13, 2022
1 parent 0f1455d commit 081f153
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public WindowFunction createWindowFunction(boolean ignoreNulls, List<Supplier<Ob
return new AggregateWindowFunction(() -> createWindowAccumulator(lambdaProviders), hasRemoveInput);
}

private WindowAccumulator createWindowAccumulator(List<Supplier<Object>> lambdaProviders)
public WindowAccumulator createWindowAccumulator(List<Supplier<Object>> lambdaProviders)
{
try {
return constructor.newInstance(lambdaProviders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,18 @@
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionNullability;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.WindowAccumulator;
import io.trino.operator.window.AggregationWindowFunctionSupplier;
import io.trino.operator.window.MappedWindowIndex;
import io.trino.operator.window.matcher.ArrayView;
import io.trino.operator.window.pattern.SetEvaluator.SetEvaluatorSupplier;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;

import java.lang.reflect.Constructor;
import java.util.List;
import java.util.function.Supplier;

import static io.trino.operator.aggregation.AccumulatorCompiler.generateWindowAccumulatorClass;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -193,8 +190,7 @@ public static class MatchAggregationInstantiator

public MatchAggregationInstantiator(
BoundSignature boundSignature,
AggregationMetadata aggregationMetadata,
FunctionNullability functionNullability,
AggregationWindowFunctionSupplier aggregationWindowFunctionSupplier,
List<Integer> argumentChannels,
List<Supplier<Object>> lambdaProviders,
SetEvaluatorSupplier setEvaluatorSupplier)
Expand All @@ -203,24 +199,13 @@ public MatchAggregationInstantiator(
this.argumentChannels = requireNonNull(argumentChannels, "argumentChannels is null");
this.setEvaluatorSupplier = requireNonNull(setEvaluatorSupplier, "setEvaluatorSupplier is null");

Constructor<? extends WindowAccumulator> constructor = generateWindowAccumulatorClass(boundSignature, aggregationMetadata, functionNullability);
this.accumulatorFactory = () -> createWindowAccumulator(constructor, lambdaProviders);
this.accumulatorFactory = () -> aggregationWindowFunctionSupplier.createWindowAccumulator(lambdaProviders);
}

public MatchAggregation get(AggregatedMemoryContext memoryContextSupplier)
{
requireNonNull(memoryContextSupplier, "memoryContextSupplier is null");
return new MatchAggregation(boundSignature, accumulatorFactory, argumentChannels, setEvaluatorSupplier.get(), memoryContextSupplier);
}

private static WindowAccumulator createWindowAccumulator(Constructor<? extends WindowAccumulator> constructor, List<Supplier<Object>> lambdaProviders)
{
try {
return constructor.newInstance(lambdaProviders);
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.sql.planner;

import com.google.common.base.VerifyException;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ContiguousSet;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableBiMap;
Expand All @@ -31,6 +32,7 @@
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.DynamicFilterConfig;
import io.trino.execution.ExplainAnalyzeContext;
Expand All @@ -41,6 +43,8 @@
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.index.IndexManager;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionId;
import io.trino.metadata.FunctionKind;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
Expand Down Expand Up @@ -256,6 +260,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
Expand All @@ -266,6 +271,7 @@
import java.util.stream.IntStream;

import static com.google.common.base.Functions.forMap;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand All @@ -291,6 +297,8 @@
import static io.trino.SystemSessionProperties.isExchangeCompressionEnabled;
import static io.trino.SystemSessionProperties.isLateMaterializationEnabled;
import static io.trino.SystemSessionProperties.isSpillEnabled;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.operator.DistinctLimitOperator.DistinctLimitOperatorFactory;
import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier;
import static io.trino.operator.PipelineExecutionStrategy.GROUPED_EXECUTION;
Expand Down Expand Up @@ -354,6 +362,7 @@
import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.HOURS;
import static java.util.stream.Collectors.partitioningBy;
import static java.util.stream.IntStream.range;

Expand Down Expand Up @@ -391,6 +400,13 @@ public class LocalExecutionPlanner
private final ExchangeManagerRegistry exchangeManagerRegistry;
private final PositionsAppenderFactory positionsAppenderFactory = new PositionsAppenderFactory();

private final NonEvictableCache<FunctionKey, AccumulatorFactory> accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, HOURS));
private final NonEvictableCache<FunctionKey, AggregationWindowFunctionSupplier> aggregationWindowFunctionSupplierCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, HOURS));

@Inject
public LocalExecutionPlanner(
PlannerContext plannerContext,
Expand Down Expand Up @@ -1167,11 +1183,13 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext
private WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
{
if (resolvedFunction.getFunctionKind() == FunctionKind.AGGREGATE) {
AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction);
return new AggregationWindowFunctionSupplier(
resolvedFunction.getSignature(),
aggregationMetadata,
resolvedFunction.getFunctionNullability());
return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), () -> {
AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction);
return new AggregationWindowFunctionSupplier(
resolvedFunction.getSignature(),
aggregationMetadata,
resolvedFunction.getFunctionNullability());
});
}
return plannerContext.getFunctionManager().getWindowFunctionImplementation(resolvedFunction);
}
Expand Down Expand Up @@ -1510,10 +1528,11 @@ else if (matchNumberSymbols.contains(pointer.getInputSymbol())) {

boolean classifierInvolved = false;

ResolvedFunction resolvedFunction = pointer.getFunction();
AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(pointer.getFunction());

ImmutableList.Builder<Map.Entry<Expression, Type>> builder = ImmutableList.builder();
List<Type> signatureTypes = pointer.getFunction().getSignature().getArgumentTypes();
List<Type> signatureTypes = resolvedFunction.getSignature().getArgumentTypes();
for (int i = 0; i < pointer.getArguments().size(); i++) {
builder.add(new SimpleEntry<>(pointer.getArguments().get(i), signatureTypes.get(i)));
}
Expand All @@ -1526,7 +1545,7 @@ else if (matchNumberSymbols.contains(pointer.getInputSymbol())) {
.map(LambdaExpression.class::cast)
.collect(toImmutableList());

List<FunctionType> functionTypes = pointer.getFunction().getSignature().getArgumentTypes().stream()
List<FunctionType> functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream()
.filter(FunctionType.class::isInstance)
.map(FunctionType.class::cast)
.collect(toImmutableList());
Expand Down Expand Up @@ -1575,10 +1594,16 @@ else if (symbol.equals(matchNumberArgumentSymbol)) {
}
}

AggregationWindowFunctionSupplier aggregationWindowFunctionSupplier = uncheckedCacheGet(
aggregationWindowFunctionSupplierCache,
new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()),
() -> new AggregationWindowFunctionSupplier(
resolvedFunction.getSignature(),
aggregationMetadata,
resolvedFunction.getFunctionNullability()));
matchAggregations.add(new MatchAggregationInstantiator(
pointer.getFunction().getSignature(),
aggregationMetadata,
pointer.getFunction().getFunctionNullability(),
resolvedFunction.getSignature(),
aggregationWindowFunctionSupplier,
valueChannels,
lambdaProviders,
new SetEvaluatorSupplier(pointer.getSetDescriptor(), mapping)));
Expand Down Expand Up @@ -3546,11 +3571,15 @@ private AggregatorFactory buildAggregatorFactory(
}
}

ResolvedFunction resolvedFunction = aggregation.getResolvedFunction();
AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(aggregation.getResolvedFunction());
AccumulatorFactory accumulatorFactory = generateAccumulatorFactory(
aggregation.getResolvedFunction().getSignature(),
aggregationMetadata,
aggregation.getResolvedFunction().getFunctionNullability());
AccumulatorFactory accumulatorFactory = uncheckedCacheGet(
accumulatorFactoryCache,
new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()),
() -> generateAccumulatorFactory(
resolvedFunction.getSignature(),
aggregationMetadata,
resolvedFunction.getFunctionNullability()));

if (aggregation.isDistinct()) {
accumulatorFactory = new DistinctAccumulatorFactory(
Expand Down Expand Up @@ -3600,7 +3629,7 @@ private AggregatorFactory buildAggregatorFactory(
.map(stateDescriptor -> stateDescriptor.getSerializer().getSerializedType())
.collect(toImmutableList());
Type intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes);
Type finalType = aggregation.getResolvedFunction().getSignature().getReturnType();
Type finalType = resolvedFunction.getSignature().getReturnType();

OptionalInt maskChannel = aggregation.getMask().stream()
.mapToInt(value -> source.getLayout().get(value))
Expand All @@ -3610,7 +3639,7 @@ private AggregatorFactory buildAggregatorFactory(
.filter(LambdaExpression.class::isInstance)
.map(LambdaExpression.class::cast)
.collect(toImmutableList());
List<FunctionType> functionTypes = aggregation.getResolvedFunction().getSignature().getArgumentTypes().stream()
List<FunctionType> functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream()
.filter(FunctionType.class::isInstance)
.map(FunctionType.class::cast)
.collect(toImmutableList());
Expand Down Expand Up @@ -4169,4 +4198,55 @@ public boolean isClassifierInvolved()
return classifierInvolved;
}
}

private static class FunctionKey
{
private final FunctionId functionId;
private final BoundSignature boundSignature;

public FunctionKey(FunctionId functionId, BoundSignature boundSignature)
{
this.functionId = requireNonNull(functionId, "functionId is null");
this.boundSignature = requireNonNull(boundSignature, "boundSignature is null");
}

public FunctionId getFunctionId()
{
return functionId;
}

public BoundSignature getBoundSignature()
{
return boundSignature;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionKey that = (FunctionKey) o;
return functionId.equals(that.functionId) &&
boundSignature.equals(that.boundSignature);
}

@Override
public int hashCode()
{
return Objects.hash(functionId, boundSignature);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("functionId", functionId)
.add("boundSignature", boundSignature)
.toString();
}
}
}

0 comments on commit 081f153

Please sign in to comment.