Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache accumulator factory #11358

Merged
merged 3 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.metadata.NodeState.ACTIVE;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -181,12 +181,6 @@ private NodeMap createNodeMap(Optional<CatalogName> catalogName)
private boolean markInaccessibleNode(InternalNode node)
{
Object marker = new Object();
try {
return inaccessibleNodeLogCache.get(node, () -> marker) == marker;
}
catch (ExecutionException e) {
// impossible
throw new RuntimeException(e);
}
return uncheckedCacheGet(inaccessibleNodeLogCache, node, () -> marker) == marker;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
import java.net.UnknownHostException;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.metadata.NodeState.ACTIVE;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -164,12 +164,6 @@ private NodeMap createNodeMap(Optional<CatalogName> catalogName)
private boolean markInaccessibleNode(InternalNode node)
{
Object marker = new Object();
try {
return inaccessibleNodeLogCache.get(node, () -> marker) == marker;
}
catch (ExecutionException e) {
// impossible
throw new RuntimeException(e);
}
return uncheckedCacheGet(inaccessibleNodeLogCache, node, () -> marker) == marker;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
import java.lang.invoke.MethodType;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.primitives.Primitives.wrap;
import static io.trino.client.NodeVersion.UNKNOWN;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
Expand Down Expand Up @@ -77,9 +77,9 @@ public FunctionManager(GlobalFunctionCatalog globalFunctionCatalog)
public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
{
try {
return specializedScalarCache.get(new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionInvokerInternal(resolvedFunction, invocationConvention));
return uncheckedCacheGet(specializedScalarCache, new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionInvokerInternal(resolvedFunction, invocationConvention));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
Expand All @@ -100,9 +100,9 @@ private FunctionInvoker getScalarFunctionInvokerInternal(ResolvedFunction resolv
public AggregationMetadata getAggregateFunctionImplementation(ResolvedFunction resolvedFunction)
{
try {
return specializedAggregationCache.get(new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction));
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
Expand All @@ -120,9 +120,9 @@ private AggregationMetadata getAggregateFunctionImplementationInternal(ResolvedF
public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
{
try {
return specializedWindowCache.get(new FunctionKey(resolvedFunction), () -> getWindowFunctionImplementationInternal(resolvedFunction));
return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionImplementationInternal(resolvedFunction));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.HOURS;
Expand Down Expand Up @@ -117,11 +117,12 @@ public FunctionInvoker getScalarFunctionInvoker(
{
ScalarFunctionImplementation scalarFunctionImplementation;
try {
scalarFunctionImplementation = specializedScalarCache.get(
scalarFunctionImplementation = uncheckedCacheGet(
specializedScalarCache,
new FunctionKey(functionId, boundSignature),
() -> specializeScalarFunction(functionId, boundSignature, functionDependencies));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
Expand All @@ -138,9 +139,9 @@ private ScalarFunctionImplementation specializeScalarFunction(FunctionId functio
public AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
try {
return specializedAggregationCache.get(new FunctionKey(functionId, boundSignature), () -> specializedAggregation(functionId, boundSignature, functionDependencies));
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(functionId, boundSignature), () -> specializedAggregation(functionId, boundSignature, functionDependencies));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
Expand All @@ -156,9 +157,9 @@ private AggregationMetadata specializedAggregation(FunctionId functionId, BoundS
public WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
try {
return specializedWindowCache.get(new FunctionKey(functionId, boundSignature), () -> specializeWindow(functionId, boundSignature, functionDependencies));
return uncheckedCacheGet(specializedWindowCache, new FunctionKey(functionId, boundSignature), () -> specializeWindow(functionId, boundSignature, functionDependencies));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -131,6 +130,7 @@
import static io.airlift.concurrent.MoreFutures.toListenableFuture;
import static io.trino.SystemSessionProperties.getRetryPolicy;
import static io.trino.client.NodeVersion.UNKNOWN;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.metadata.QualifiedObjectName.convertFromSchemaTableName;
Expand Down Expand Up @@ -2027,14 +2027,11 @@ public ResolvedFunction resolveOperator(Session session, OperatorType operatorTy
{
try {
// todo we should not be caching functions across session
return operatorCache.get(new OperatorCacheKey(operatorType, argumentTypes), () -> {
return uncheckedCacheGet(operatorCache, new OperatorCacheKey(operatorType, argumentTypes), () -> {
String name = mangleOperatorName(operatorType);
return resolvedFunctionInternal(session, QualifiedName.of(name), fromTypes(argumentTypes));
});
}
catch (ExecutionException e) {
throw new UncheckedExecutionException(e);
}
catch (UncheckedExecutionException e) {
if (e.getCause() instanceof TrinoException) {
TrinoException cause = (TrinoException) e.getCause();
Expand All @@ -2059,7 +2056,7 @@ public ResolvedFunction getCoercion(Session session, OperatorType operatorType,
checkArgument(operatorType == OperatorType.CAST || operatorType == OperatorType.SATURATED_FLOOR_CAST);
try {
// todo we should not be caching functions across session
return coercionCache.get(new CoercionCacheKey(operatorType, fromType, toType), () -> {
return uncheckedCacheGet(coercionCache, new CoercionCacheKey(operatorType, fromType, toType), () -> {
String name = mangleOperatorName(operatorType);
FunctionBinding functionBinding = functionResolver.resolveCoercion(
session,
Expand All @@ -2068,9 +2065,6 @@ public ResolvedFunction getCoercion(Session session, OperatorType operatorType,
return resolve(session, functionBinding);
});
}
catch (ExecutionException e) {
throw new UncheckedExecutionException(e);
}
catch (UncheckedExecutionException e) {
if (e.getCause() instanceof TrinoException) {
TrinoException cause = (TrinoException) e.getCause();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
Expand Down Expand Up @@ -177,9 +177,9 @@ public Type getType(TypeSignature signature)
Type type = types.get(signature);
if (type == null) {
try {
return parametricTypeCache.get(signature, () -> instantiateParametricType(signature));
return uncheckedCacheGet(parametricTypeCache, signature, () -> instantiateParametricType(signature));
}
catch (ExecutionException | UncheckedExecutionException e) {
catch (UncheckedExecutionException e) {
throwIfUnchecked(e.getCause());
throw new RuntimeException(e.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ private AccumulatorCompiler() {}
public static AccumulatorFactory generateAccumulatorFactory(
BoundSignature boundSignature,
AggregationMetadata metadata,
FunctionNullability functionNullability,
List<Supplier<Object>> lambdaProviders)
FunctionNullability functionNullability)
{
// change types used in Aggregation methods to types used in the core Trino engine to simplify code generation
metadata = normalizeAggregationMethods(metadata);
Expand All @@ -115,7 +114,7 @@ public static AccumulatorFactory generateAccumulatorFactory(
return new CompiledAccumulatorFactory(
accumulatorConstructor,
groupedAccumulatorConstructor,
lambdaProviders);
metadata.getLambdaInterfaces());
}

private static <T> Constructor<? extends T> generateAccumulatorClass(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
*/
package io.trino.operator.aggregation;

import java.util.List;
import java.util.function.Supplier;

public interface AccumulatorFactory
{
Accumulator createAccumulator();
List<Class<?>> getLambdaInterfaces();

Accumulator createAccumulator(List<Supplier<Object>> lambdaProviders);

Accumulator createIntermediateAccumulator();
Accumulator createIntermediateAccumulator(List<Supplier<Object>> lambdaProviders);

GroupedAccumulator createGroupedAccumulator();
GroupedAccumulator createGroupedAccumulator(List<Supplier<Object>> lambdaProviders);

GroupedAccumulator createGroupedIntermediateAccumulator();
GroupedAccumulator createGroupedIntermediateAccumulator(List<Supplier<Object>> lambdaProviders);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.List;
import java.util.OptionalInt;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
Expand All @@ -32,6 +33,7 @@ public class AggregatorFactory
private final List<Integer> inputChannels;
private final OptionalInt maskChannel;
private final boolean spillable;
private final List<Supplier<Object>> lambdaProviders;

public AggregatorFactory(
AccumulatorFactory accumulatorFactory,
Expand All @@ -40,7 +42,8 @@ public AggregatorFactory(
Type finalType,
List<Integer> inputChannels,
OptionalInt maskChannel,
boolean spillable)
boolean spillable,
List<Supplier<Object>> lambdaProviders)
{
this.accumulatorFactory = requireNonNull(accumulatorFactory, "accumulatorFactory is null");
this.step = requireNonNull(step, "step is null");
Expand All @@ -49,6 +52,7 @@ public AggregatorFactory(
this.inputChannels = ImmutableList.copyOf(requireNonNull(inputChannels, "inputChannels is null"));
this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
this.spillable = spillable;
this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null"));

checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
}
Expand All @@ -57,10 +61,10 @@ public Aggregator createAggregator()
{
Accumulator accumulator;
if (step.isInputRaw()) {
accumulator = accumulatorFactory.createAccumulator();
accumulator = accumulatorFactory.createAccumulator(lambdaProviders);
}
else {
accumulator = accumulatorFactory.createIntermediateAccumulator();
accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders);
}
return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel);
}
Expand All @@ -69,10 +73,10 @@ public GroupedAggregator createGroupedAggregator()
{
GroupedAccumulator accumulator;
if (step.isInputRaw()) {
accumulator = accumulatorFactory.createGroupedAccumulator();
accumulator = accumulatorFactory.createGroupedAccumulator(lambdaProviders);
}
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator();
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel);
}
Expand All @@ -81,10 +85,10 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan
{
GroupedAccumulator accumulator;
if (step.isInputRaw()) {
accumulator = accumulatorFactory.createGroupedAccumulator();
accumulator = accumulatorFactory.createGroupedAccumulator(lambdaProviders);
}
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator();
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel);
}
Expand Down
Loading