From e966733a4b8dd19e590db6ba3138e2e39d61e79d Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Fri, 11 Mar 2022 14:34:48 -0800 Subject: [PATCH 1/3] Add CacheUtils.uncheckedCacheGet helper The uncheckedCacheGet method uses a Supplier instead of a Callable and therefore can not throw a checked exception, so a checked is not thrown from uncheckedCacheGet. This simplifies most callers which do not use checked exceptions. --- .../TopologyAwareNodeSelectorFactory.java | 10 ++---- .../scheduler/UniformNodeSelectorFactory.java | 10 ++---- .../io/trino/metadata/FunctionManager.java | 14 ++++---- .../metadata/InternalFunctionBundle.java | 15 ++++---- .../io/trino/metadata/MetadataManager.java | 12 ++----- .../java/io/trino/metadata/TypeRegistry.java | 6 ++-- .../io/trino/type/BlockTypeOperators.java | 7 ++-- .../io/trino/type/TypeOperatorsCache.java | 6 ++-- .../java/io/trino/util/FastutilSetHelper.java | 7 ++-- .../io/trino/collect/cache/CacheUtils.java | 35 +++++++++++++++++++ .../bigquery/BigQueryClientFactory.java | 9 ++--- .../bigquery/ViewMaterializationCache.java | 17 +++------ .../CachingDeltaLakeStatisticsAccess.java | 6 ++-- .../plugin/hive/CachingDirectoryLister.java | 10 ++---- .../java/io/trino/plugin/ml/MLFunctions.java | 10 ++---- 15 files changed, 85 insertions(+), 89 deletions(-) create mode 100644 lib/trino-collect/src/main/java/io/trino/collect/cache/CacheUtils.java diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java index 9f1e6fcd2a9f..5214a0cc5206 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/TopologyAwareNodeSelectorFactory.java @@ -38,13 +38,13 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.metadata.NodeState.ACTIVE; import static java.util.Objects.requireNonNull; @@ -181,12 +181,6 @@ private NodeMap createNodeMap(Optional catalogName) private boolean markInaccessibleNode(InternalNode node) { Object marker = new Object(); - try { - return inaccessibleNodeLogCache.get(node, () -> marker) == marker; - } - catch (ExecutionException e) { - // impossible - throw new RuntimeException(e); - } + return uncheckedCacheGet(inaccessibleNodeLogCache, node, () -> marker) == marker; } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java index e637de5006ec..5605a78beb42 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/UniformNodeSelectorFactory.java @@ -35,13 +35,13 @@ import java.net.UnknownHostException; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.metadata.NodeState.ACTIVE; import static java.util.Objects.requireNonNull; @@ -164,12 +164,6 @@ private NodeMap createNodeMap(Optional catalogName) private boolean markInaccessibleNode(InternalNode node) { Object marker = new Object(); - try { - return inaccessibleNodeLogCache.get(node, () -> marker) == marker; - } - catch (ExecutionException e) { - // impossible - throw new RuntimeException(e); - } + return uncheckedCacheGet(inaccessibleNodeLogCache, node, () -> marker) == marker; } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java index acaef67c3f2d..39d683a8d4a2 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java @@ -34,13 +34,13 @@ import java.lang.invoke.MethodType; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.ExecutionException; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.primitives.Primitives.wrap; import static io.trino.client.NodeVersion.UNKNOWN; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; @@ -77,9 +77,9 @@ public FunctionManager(GlobalFunctionCatalog globalFunctionCatalog) public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) { try { - return specializedScalarCache.get(new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionInvokerInternal(resolvedFunction, invocationConvention)); + return uncheckedCacheGet(specializedScalarCache, new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionInvokerInternal(resolvedFunction, invocationConvention)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } @@ -100,9 +100,9 @@ private FunctionInvoker getScalarFunctionInvokerInternal(ResolvedFunction resolv public AggregationMetadata getAggregateFunctionImplementation(ResolvedFunction resolvedFunction) { try { - return specializedAggregationCache.get(new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction)); + return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } @@ -120,9 +120,9 @@ private AggregationMetadata getAggregateFunctionImplementationInternal(ResolvedF public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction) { try { - return specializedWindowCache.get(new FunctionKey(resolvedFunction), () -> getWindowFunctionImplementationInternal(resolvedFunction)); + return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionImplementationInternal(resolvedFunction)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java index 5710006fe157..e33638998f0d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java @@ -35,7 +35,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ExecutionException; import java.util.function.Function; import static com.google.common.base.MoreObjects.toStringHelper; @@ -43,6 +42,7 @@ import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.HOURS; @@ -117,11 +117,12 @@ public FunctionInvoker getScalarFunctionInvoker( { ScalarFunctionImplementation scalarFunctionImplementation; try { - scalarFunctionImplementation = specializedScalarCache.get( + scalarFunctionImplementation = uncheckedCacheGet( + specializedScalarCache, new FunctionKey(functionId, boundSignature), () -> specializeScalarFunction(functionId, boundSignature, functionDependencies)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } @@ -138,9 +139,9 @@ private ScalarFunctionImplementation specializeScalarFunction(FunctionId functio public AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { try { - return specializedAggregationCache.get(new FunctionKey(functionId, boundSignature), () -> specializedAggregation(functionId, boundSignature, functionDependencies)); + return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(functionId, boundSignature), () -> specializedAggregation(functionId, boundSignature, functionDependencies)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } @@ -156,9 +157,9 @@ private AggregationMetadata specializedAggregation(FunctionId functionId, BoundS public WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { try { - return specializedWindowCache.get(new FunctionKey(functionId, boundSignature), () -> specializeWindow(functionId, boundSignature, functionDependencies)); + return uncheckedCacheGet(specializedWindowCache, new FunctionKey(functionId, boundSignature), () -> specializeWindow(functionId, boundSignature, functionDependencies)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 52daf8cb85de..c944672fd24c 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -115,7 +115,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ExecutionException; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -131,6 +130,7 @@ import static io.airlift.concurrent.MoreFutures.toListenableFuture; import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.client.NodeVersion.UNKNOWN; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.metadata.FunctionKind.AGGREGATE; import static io.trino.metadata.QualifiedObjectName.convertFromSchemaTableName; @@ -2027,14 +2027,11 @@ public ResolvedFunction resolveOperator(Session session, OperatorType operatorTy { try { // todo we should not be caching functions across session - return operatorCache.get(new OperatorCacheKey(operatorType, argumentTypes), () -> { + return uncheckedCacheGet(operatorCache, new OperatorCacheKey(operatorType, argumentTypes), () -> { String name = mangleOperatorName(operatorType); return resolvedFunctionInternal(session, QualifiedName.of(name), fromTypes(argumentTypes)); }); } - catch (ExecutionException e) { - throw new UncheckedExecutionException(e); - } catch (UncheckedExecutionException e) { if (e.getCause() instanceof TrinoException) { TrinoException cause = (TrinoException) e.getCause(); @@ -2059,7 +2056,7 @@ public ResolvedFunction getCoercion(Session session, OperatorType operatorType, checkArgument(operatorType == OperatorType.CAST || operatorType == OperatorType.SATURATED_FLOOR_CAST); try { // todo we should not be caching functions across session - return coercionCache.get(new CoercionCacheKey(operatorType, fromType, toType), () -> { + return uncheckedCacheGet(coercionCache, new CoercionCacheKey(operatorType, fromType, toType), () -> { String name = mangleOperatorName(operatorType); FunctionBinding functionBinding = functionResolver.resolveCoercion( session, @@ -2068,9 +2065,6 @@ public ResolvedFunction getCoercion(Session session, OperatorType operatorType, return resolve(session, functionBinding); }); } - catch (ExecutionException e) { - throw new UncheckedExecutionException(e); - } catch (UncheckedExecutionException e) { if (e.getCause() instanceof TrinoException) { TrinoException cause = (TrinoException) e.getCause(); diff --git a/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java index 6b4da83c8168..8831054888f9 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java @@ -48,11 +48,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ExecutionException; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; @@ -177,9 +177,9 @@ public Type getType(TypeSignature signature) Type type = types.get(signature); if (type == null) { try { - return parametricTypeCache.get(signature, () -> instantiateParametricType(signature)); + return uncheckedCacheGet(parametricTypeCache, signature, () -> instantiateParametricType(signature)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfUnchecked(e.getCause()); throw new RuntimeException(e.getCause()); } diff --git a/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java b/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java index 664497fc91c0..8b1a378a2c0d 100644 --- a/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java +++ b/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java @@ -29,11 +29,11 @@ import java.lang.invoke.MethodHandle; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCacheWithWeakInvalidateAll; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -175,12 +175,13 @@ private T getBlockOperator(Type type, Class operatorInterface, Supplier generatedBlockOperator = (GeneratedBlockOperator) generatedBlockOperatorCache.get( + GeneratedBlockOperator generatedBlockOperator = (GeneratedBlockOperator) uncheckedCacheGet( + generatedBlockOperatorCache, new GeneratedBlockOperatorKey<>(type, operatorInterface, additionalKey), () -> new GeneratedBlockOperator<>(type, operatorInterface, methodHandleSupplier.get())); return generatedBlockOperator.get(); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfUnchecked(e.getCause()); throw new RuntimeException(e.getCause()); } diff --git a/core/trino-main/src/main/java/io/trino/type/TypeOperatorsCache.java b/core/trino-main/src/main/java/io/trino/type/TypeOperatorsCache.java index b44686b0a63a..9d822d53a01a 100644 --- a/core/trino-main/src/main/java/io/trino/type/TypeOperatorsCache.java +++ b/core/trino-main/src/main/java/io/trino/type/TypeOperatorsCache.java @@ -18,11 +18,11 @@ import io.trino.collect.cache.NonKeyEvictableCache; import org.weakref.jmx.Managed; -import java.util.concurrent.ExecutionException; import java.util.function.BiFunction; import java.util.function.Supplier; import static com.google.common.base.Throwables.throwIfUnchecked; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCacheWithWeakInvalidateAll; public class TypeOperatorsCache @@ -36,9 +36,9 @@ public class TypeOperatorsCache public Object apply(Object operatorConvention, Supplier supplier) { try { - return cache.get(operatorConvention, supplier::get); + return uncheckedCacheGet(cache, operatorConvention, supplier); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfUnchecked(e.getCause()); throw new RuntimeException(e.getCause()); } diff --git a/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java b/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java index 9c5d9433a413..2963e9ab19d9 100644 --- a/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java +++ b/core/trino-main/src/main/java/io/trino/util/FastutilSetHelper.java @@ -31,11 +31,11 @@ import java.util.Collection; import java.util.Objects; import java.util.Set; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verifyNotNull; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.util.SingleAccessMethodCompiler.compileSingleAccessMethod; import static java.lang.Boolean.TRUE; @@ -248,12 +248,13 @@ private static T getGeneratedMethod(Type type, Class operatorInterface, M { try { @SuppressWarnings("unchecked") - GeneratedMethod generatedMethod = (GeneratedMethod) generatedMethodCache.get( + GeneratedMethod generatedMethod = (GeneratedMethod) uncheckedCacheGet( + generatedMethodCache, new MethodKey<>(type, operatorInterface), () -> new GeneratedMethod<>(type, operatorInterface, methodHandle)); return generatedMethod.get(); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfUnchecked(e.getCause()); throw new RuntimeException(e.getCause()); } diff --git a/lib/trino-collect/src/main/java/io/trino/collect/cache/CacheUtils.java b/lib/trino-collect/src/main/java/io/trino/collect/cache/CacheUtils.java new file mode 100644 index 000000000000..601164fe795a --- /dev/null +++ b/lib/trino-collect/src/main/java/io/trino/collect/cache/CacheUtils.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.collect.cache; + +import com.google.common.cache.Cache; + +import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; + +public final class CacheUtils +{ + private CacheUtils() {} + + public static V uncheckedCacheGet(Cache cache, K key, Supplier loader) + { + try { + return cache.get(key, loader::get); + } + catch (ExecutionException e) { + // this can not happen because a supplier can not throw a checked exception + throw new RuntimeException("Unexpected checked exception from cache load", e); + } + } +} diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClientFactory.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClientFactory.java index 33047b118a05..09be5dcafc27 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClientFactory.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClientFactory.java @@ -25,8 +25,8 @@ import javax.inject.Inject; import java.util.Optional; -import java.util.concurrent.ExecutionException; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -64,12 +64,7 @@ public BigQueryClient create(ConnectorSession session) { IdentityCacheMapping.IdentityCacheKey cacheKey = identityCacheMapping.getRemoteUserCacheKey(session); - try { - return clientCache.get(cacheKey, () -> createBigQueryClient(session)); - } - catch (ExecutionException e) { - return createBigQueryClient(session); - } + return uncheckedCacheGet(clientCache, cacheKey, () -> createBigQueryClient(session)); } protected BigQueryClient createBigQueryClient(ConnectorSession session) diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ViewMaterializationCache.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ViewMaterializationCache.java index 09a336e7bbf3..cbaf70e650ae 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ViewMaterializationCache.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ViewMaterializationCache.java @@ -25,18 +25,16 @@ import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.collect.cache.NonEvictableCache; -import io.trino.spi.TrinoException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableNotFoundException; import javax.inject.Inject; import java.util.Optional; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; -import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED; import static io.trino.plugin.bigquery.BigQueryUtil.convertToBigQueryException; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -66,12 +64,7 @@ public ViewMaterializationCache(BigQueryConfig config) public TableInfo getCachedTable(BigQueryClient client, String query, Duration viewExpiration, TableInfo remoteTableId) { - try { - return destinationTableCache.get(query, new DestinationTableBuilder(client, viewExpiration, query, createDestinationTable(remoteTableId.getTableId()))); - } - catch (ExecutionException e) { - throw new TrinoException(BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED, "Error creating destination table", e); - } + return uncheckedCacheGet(destinationTableCache, query, new DestinationTableBuilder(client, viewExpiration, query, createDestinationTable(remoteTableId.getTableId()))); } private TableId createDestinationTable(TableId remoteTableId) @@ -84,7 +77,7 @@ private TableId createDestinationTable(TableId remoteTableId) } private static class DestinationTableBuilder - implements Callable + implements Supplier { private final BigQueryClient bigQueryClient; private final Duration viewExpiration; @@ -100,7 +93,7 @@ private static class DestinationTableBuilder } @Override - public TableInfo call() + public TableInfo get() { return createTableFromQuery(); } diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/CachingDeltaLakeStatisticsAccess.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/CachingDeltaLakeStatisticsAccess.java index bef18983c740..5bfc8c6fd2f7 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/CachingDeltaLakeStatisticsAccess.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/statistics/CachingDeltaLakeStatisticsAccess.java @@ -26,9 +26,9 @@ import java.lang.annotation.Target; import java.time.Duration; import java.util.Optional; -import java.util.concurrent.ExecutionException; import static com.google.common.base.Throwables.throwIfInstanceOf; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static java.lang.annotation.ElementType.FIELD; import static java.lang.annotation.ElementType.METHOD; @@ -59,9 +59,9 @@ public CachingDeltaLakeStatisticsAccess(@ForCachingDeltaLakeStatisticsAccess Del public Optional readDeltaLakeStatistics(ConnectorSession session, String tableLocation) { try { - return cache.get(tableLocation, () -> delegate.readDeltaLakeStatistics(session, tableLocation)); + return uncheckedCacheGet(cache, tableLocation, () -> delegate.readDeltaLakeStatistics(session, tableLocation)); } - catch (ExecutionException | UncheckedExecutionException e) { + catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); throw new TrinoException(GENERIC_INTERNAL_ERROR, "Error reading statistics from cache", e.getCause()); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/CachingDirectoryLister.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/CachingDirectoryLister.java index 04ff66eaa0f6..66f2ddd7754b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/CachingDirectoryLister.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/CachingDirectoryLister.java @@ -37,11 +37,11 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static java.util.Objects.requireNonNull; import static org.apache.commons.lang3.StringUtils.isNotEmpty; @@ -96,13 +96,7 @@ public RemoteIterator list(FileSystem fs, Table table, Path p return fs.listLocatedStatus(path); } - ValueHolder cachedValueHolder; - try { - cachedValueHolder = cache.get(path, ValueHolder::new); - } - catch (ExecutionException e) { - throw new RuntimeException(e); // cannot happen - } + ValueHolder cachedValueHolder = uncheckedCacheGet(cache, path, ValueHolder::new); if (cachedValueHolder.getFiles().isPresent()) { return simpleRemoteIterator(cachedValueHolder.getFiles().get()); } diff --git a/plugin/trino-ml/src/main/java/io/trino/plugin/ml/MLFunctions.java b/plugin/trino-ml/src/main/java/io/trino/plugin/ml/MLFunctions.java index 4a39fc05f8d7..50d205f5c184 100644 --- a/plugin/trino-ml/src/main/java/io/trino/plugin/ml/MLFunctions.java +++ b/plugin/trino-ml/src/main/java/io/trino/plugin/ml/MLFunctions.java @@ -24,9 +24,8 @@ import io.trino.spi.function.SqlType; import io.trino.spi.type.StandardTypes; -import java.util.concurrent.ExecutionException; - import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.plugin.ml.type.ClassifierType.BIGINT_CLASSIFIER; import static io.trino.plugin.ml.type.ClassifierType.VARCHAR_CLASSIFIER; @@ -77,11 +76,6 @@ public static double regress(@SqlType(MAP_BIGINT_DOUBLE) Block featuresMap, @Sql private static Model getOrLoadModel(Slice slice) { HashCode modelHash = ModelUtils.modelHash(slice); - try { - return MODEL_CACHE.get(modelHash, () -> ModelUtils.deserialize(slice)); - } - catch (ExecutionException e) { - throw new RuntimeException(e); - } + return uncheckedCacheGet(MODEL_CACHE, modelHash, () -> ModelUtils.deserialize(slice)); } } From 0f1455db3ce3d2fcbedfefd7306b97d4d94f0f1e Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Mon, 7 Mar 2022 17:36:47 -0800 Subject: [PATCH 2/3] Pass LambdaProviders to AccumulatorFactory instead of storing them Remove LambdaProvider instances from AccumulatorFactory to simplify caching of the factories. --- .../aggregation/AccumulatorCompiler.java | 5 ++- .../aggregation/AccumulatorFactory.java | 13 +++++--- .../aggregation/AggregatorFactory.java | 18 ++++++---- .../CompiledAccumulatorFactory.java | 20 +++++++---- .../DistinctAccumulatorFactory.java | 23 ++++++++----- .../OrderedAccumulatorFactory.java | 23 ++++++++----- .../sql/planner/LocalExecutionPlanner.java | 33 ++++++++++--------- .../aggregation/TestAccumulatorCompiler.java | 2 +- .../TestingAggregationFunction.java | 5 +-- .../BenchmarkAggregationFunction.java | 9 ++--- 10 files changed, 89 insertions(+), 62 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 7c837620d574..81b5482db904 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -87,8 +87,7 @@ private AccumulatorCompiler() {} public static AccumulatorFactory generateAccumulatorFactory( BoundSignature boundSignature, AggregationMetadata metadata, - FunctionNullability functionNullability, - List> lambdaProviders) + FunctionNullability functionNullability) { // change types used in Aggregation methods to types used in the core Trino engine to simplify code generation metadata = normalizeAggregationMethods(metadata); @@ -115,7 +114,7 @@ public static AccumulatorFactory generateAccumulatorFactory( return new CompiledAccumulatorFactory( accumulatorConstructor, groupedAccumulatorConstructor, - lambdaProviders); + metadata.getLambdaInterfaces()); } private static Constructor generateAccumulatorClass( diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java index 7fcafff526d5..374ce655a2c2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorFactory.java @@ -13,13 +13,18 @@ */ package io.trino.operator.aggregation; +import java.util.List; +import java.util.function.Supplier; + public interface AccumulatorFactory { - Accumulator createAccumulator(); + List> getLambdaInterfaces(); + + Accumulator createAccumulator(List> lambdaProviders); - Accumulator createIntermediateAccumulator(); + Accumulator createIntermediateAccumulator(List> lambdaProviders); - GroupedAccumulator createGroupedAccumulator(); + GroupedAccumulator createGroupedAccumulator(List> lambdaProviders); - GroupedAccumulator createGroupedIntermediateAccumulator(); + GroupedAccumulator createGroupedIntermediateAccumulator(List> lambdaProviders); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java index 63af443f4c64..968162f39347 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.OptionalInt; +import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -32,6 +33,7 @@ public class AggregatorFactory private final List inputChannels; private final OptionalInt maskChannel; private final boolean spillable; + private final List> lambdaProviders; public AggregatorFactory( AccumulatorFactory accumulatorFactory, @@ -40,7 +42,8 @@ public AggregatorFactory( Type finalType, List inputChannels, OptionalInt maskChannel, - boolean spillable) + boolean spillable, + List> lambdaProviders) { this.accumulatorFactory = requireNonNull(accumulatorFactory, "accumulatorFactory is null"); this.step = requireNonNull(step, "step is null"); @@ -49,6 +52,7 @@ public AggregatorFactory( this.inputChannels = ImmutableList.copyOf(requireNonNull(inputChannels, "inputChannels is null")); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); this.spillable = spillable; + this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null")); checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } @@ -57,10 +61,10 @@ public Aggregator createAggregator() { Accumulator accumulator; if (step.isInputRaw()) { - accumulator = accumulatorFactory.createAccumulator(); + accumulator = accumulatorFactory.createAccumulator(lambdaProviders); } else { - accumulator = accumulatorFactory.createIntermediateAccumulator(); + accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders); } return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel); } @@ -69,10 +73,10 @@ public GroupedAggregator createGroupedAggregator() { GroupedAccumulator accumulator; if (step.isInputRaw()) { - accumulator = accumulatorFactory.createGroupedAccumulator(); + accumulator = accumulatorFactory.createGroupedAccumulator(lambdaProviders); } else { - accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(); + accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel); } @@ -81,10 +85,10 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan { GroupedAccumulator accumulator; if (step.isInputRaw()) { - accumulator = accumulatorFactory.createGroupedAccumulator(); + accumulator = accumulatorFactory.createGroupedAccumulator(lambdaProviders); } else { - accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(); + accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java index 1b91f2a0c4ab..71f32311d7c3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/CompiledAccumulatorFactory.java @@ -26,20 +26,26 @@ public class CompiledAccumulatorFactory { private final Constructor accumulatorConstructor; private final Constructor groupedAccumulatorConstructor; - private final List> lambdaProviders; + private final List> lambdaInterfaces; public CompiledAccumulatorFactory( Constructor accumulatorConstructor, Constructor groupedAccumulatorConstructor, - List> lambdaProviders) + List> lambdaInterfaces) { this.accumulatorConstructor = requireNonNull(accumulatorConstructor, "accumulatorConstructor is null"); this.groupedAccumulatorConstructor = requireNonNull(groupedAccumulatorConstructor, "groupedAccumulatorConstructor is null"); - this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null")); + this.lambdaInterfaces = ImmutableList.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); } @Override - public Accumulator createAccumulator() + public List> getLambdaInterfaces() + { + return lambdaInterfaces; + } + + @Override + public Accumulator createAccumulator(List> lambdaProviders) { try { return accumulatorConstructor.newInstance(lambdaProviders); @@ -50,7 +56,7 @@ public Accumulator createAccumulator() } @Override - public Accumulator createIntermediateAccumulator() + public Accumulator createIntermediateAccumulator(List> lambdaProviders) { try { return accumulatorConstructor.newInstance(lambdaProviders); @@ -61,7 +67,7 @@ public Accumulator createIntermediateAccumulator() } @Override - public GroupedAccumulator createGroupedAccumulator() + public GroupedAccumulator createGroupedAccumulator(List> lambdaProviders) { try { return groupedAccumulatorConstructor.newInstance(lambdaProviders); @@ -72,7 +78,7 @@ public GroupedAccumulator createGroupedAccumulator() } @Override - public GroupedAccumulator createGroupedIntermediateAccumulator() + public GroupedAccumulator createGroupedIntermediateAccumulator(List> lambdaProviders) { try { return groupedAccumulatorConstructor.newInstance(lambdaProviders); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java index dc99887162d0..aad714ab3def 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DistinctAccumulatorFactory.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Optional; +import java.util.function.Supplier; import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkState; @@ -60,10 +61,16 @@ public DistinctAccumulatorFactory( } @Override - public Accumulator createAccumulator() + public List> getLambdaInterfaces() + { + return delegate.getLambdaInterfaces(); + } + + @Override + public Accumulator createAccumulator(List> lambdaProviders) { return new DistinctAccumulator( - delegate.createAccumulator(), + delegate.createAccumulator(lambdaProviders), argumentTypes, session, joinCompiler, @@ -71,16 +78,16 @@ public Accumulator createAccumulator() } @Override - public Accumulator createIntermediateAccumulator() + public Accumulator createIntermediateAccumulator(List> lambdaProviders) { - return delegate.createIntermediateAccumulator(); + return delegate.createIntermediateAccumulator(lambdaProviders); } @Override - public GroupedAccumulator createGroupedAccumulator() + public GroupedAccumulator createGroupedAccumulator(List> lambdaProviders) { return new DistinctGroupedAccumulator( - delegate.createGroupedAccumulator(), + delegate.createGroupedAccumulator(lambdaProviders), argumentTypes, session, joinCompiler, @@ -88,9 +95,9 @@ public GroupedAccumulator createGroupedAccumulator() } @Override - public GroupedAccumulator createGroupedIntermediateAccumulator() + public GroupedAccumulator createGroupedIntermediateAccumulator(List> lambdaProviders) { - return delegate.createGroupedIntermediateAccumulator(); + return delegate.createGroupedIntermediateAccumulator(lambdaProviders); } private static class DistinctAccumulator diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java index f86aecb9b7fe..468938fe0faa 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedAccumulatorFactory.java @@ -28,6 +28,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.type.BigintType.BIGINT; @@ -64,29 +65,35 @@ public OrderedAccumulatorFactory( } @Override - public Accumulator createAccumulator() + public List> getLambdaInterfaces() { - Accumulator accumulator = delegate.createAccumulator(); + return delegate.getLambdaInterfaces(); + } + + @Override + public Accumulator createAccumulator(List> lambdaProviders) + { + Accumulator accumulator = delegate.createAccumulator(lambdaProviders); return new OrderedAccumulator(accumulator, sourceTypes, argumentChannels, orderByChannels, orderings, pagesIndexFactory); } @Override - public Accumulator createIntermediateAccumulator() + public Accumulator createIntermediateAccumulator(List> lambdaProviders) { - return delegate.createIntermediateAccumulator(); + return delegate.createIntermediateAccumulator(lambdaProviders); } @Override - public GroupedAccumulator createGroupedAccumulator() + public GroupedAccumulator createGroupedAccumulator(List> lambdaProviders) { - GroupedAccumulator accumulator = delegate.createGroupedAccumulator(); + GroupedAccumulator accumulator = delegate.createGroupedAccumulator(lambdaProviders); return new OrderingGroupedAccumulator(accumulator, sourceTypes, argumentChannels, orderByChannels, orderings, pagesIndexFactory); } @Override - public GroupedAccumulator createGroupedIntermediateAccumulator() + public GroupedAccumulator createGroupedIntermediateAccumulator(List> lambdaProviders) { - return delegate.createGroupedIntermediateAccumulator(); + return delegate.createGroupedIntermediateAccumulator(lambdaProviders); } private static class OrderedAccumulator diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index a3d3006c9af2..4f092b32fe4f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -3546,25 +3546,11 @@ private AggregatorFactory buildAggregatorFactory( } } - OptionalInt maskChannel = aggregation.getMask().stream() - .mapToInt(value -> source.getLayout().get(value)) - .findAny(); AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(aggregation.getResolvedFunction()); - List lambdaExpressions = aggregation.getArguments().stream() - .filter(LambdaExpression.class::isInstance) - .map(LambdaExpression.class::cast) - .collect(toImmutableList()); - List functionTypes = aggregation.getResolvedFunction().getSignature().getArgumentTypes().stream() - .filter(FunctionType.class::isInstance) - .map(FunctionType.class::cast) - .collect(toImmutableList()); - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationMetadata.getLambdaInterfaces(), functionTypes); - AccumulatorFactory accumulatorFactory = generateAccumulatorFactory( aggregation.getResolvedFunction().getSignature(), aggregationMetadata, - aggregation.getResolvedFunction().getFunctionNullability(), - lambdaProviders); + aggregation.getResolvedFunction().getFunctionNullability()); if (aggregation.isDistinct()) { accumulatorFactory = new DistinctAccumulatorFactory( @@ -3616,6 +3602,20 @@ private AggregatorFactory buildAggregatorFactory( Type intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); Type finalType = aggregation.getResolvedFunction().getSignature().getReturnType(); + OptionalInt maskChannel = aggregation.getMask().stream() + .mapToInt(value -> source.getLayout().get(value)) + .findAny(); + + List lambdaExpressions = aggregation.getArguments().stream() + .filter(LambdaExpression.class::isInstance) + .map(LambdaExpression.class::cast) + .collect(toImmutableList()); + List functionTypes = aggregation.getResolvedFunction().getSignature().getArgumentTypes().stream() + .filter(FunctionType.class::isInstance) + .map(FunctionType.class::cast) + .collect(toImmutableList()); + List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationMetadata.getLambdaInterfaces(), functionTypes); + return new AggregatorFactory( accumulatorFactory, step, @@ -3623,7 +3623,8 @@ private AggregatorFactory buildAggregatorFactory( finalType, argumentChannels, maskChannel, - !aggregation.isDistinct() && aggregation.getOrderingScheme().isEmpty()); + !aggregation.isDistinct() && aggregation.getOrderingScheme().isEmpty(), + lambdaProviders); } private List> makeLambdaProviders(List lambdaExpressions, List> lambdaInterfaces, List functionTypes) diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index 96fe74825b58..24d174c3750f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -97,7 +97,7 @@ private static void assertGenerateAccumulator(Cl FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); // test if we can compile aggregation - AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability, ImmutableList.of()); + AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability); assertThat(accumulatorFactory).isNotNull(); assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, metadata, functionNullability)).isNotNull(); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index 532c59584046..eb9ea8e3fe91 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -51,7 +51,7 @@ public TestingAggregationFunction(BoundSignature signature, FunctionNullability .collect(toImmutableList()); intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); this.finalType = signature.getReturnType(); - this.factory = generateAccumulatorFactory(signature, aggregationMetadata, functionNullability, ImmutableList.of()); + this.factory = generateAccumulatorFactory(signature, aggregationMetadata, functionNullability); distinctFactory = new DistinctAccumulatorFactory( factory, parameterTypes, @@ -114,6 +114,7 @@ private AggregatorFactory createAggregatorFactory(Step step, List input finalType, inputChannels, maskChannel, - true); + true, + ImmutableList.of()); } } diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java index 80096874c38f..0e00c363ab30 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java @@ -39,11 +39,7 @@ public BenchmarkAggregationFunction(ResolvedFunction resolvedFunction, Aggregati BoundSignature signature = resolvedFunction.getSignature(); intermediateType = getOnlyElement(aggregationMetadata.getAccumulatorStateDescriptors()).getSerializer().getSerializedType(); finalType = signature.getReturnType(); - accumulatorFactory = generateAccumulatorFactory( - signature, - aggregationMetadata, - resolvedFunction.getFunctionNullability(), - ImmutableList.of()); + accumulatorFactory = generateAccumulatorFactory(signature, aggregationMetadata, resolvedFunction.getFunctionNullability()); } public AggregatorFactory bind(List inputChannels) @@ -55,6 +51,7 @@ public AggregatorFactory bind(List inputChannels) finalType, inputChannels, OptionalInt.empty(), - true); + true, + ImmutableList.of()); } } From 081f153eb3e87cda0cd576e8f5024c2cfa0f864b Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Mon, 7 Mar 2022 18:00:29 -0800 Subject: [PATCH 3/3] Cache generated AccumulatorFactory classes --- .../AggregationWindowFunctionSupplier.java | 2 +- .../window/pattern/MatchAggregation.java | 21 +--- .../sql/planner/LocalExecutionPlanner.java | 112 +++++++++++++++--- 3 files changed, 100 insertions(+), 35 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java b/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java index cea2f6685505..8f38d11dd62b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java @@ -54,7 +54,7 @@ public WindowFunction createWindowFunction(boolean ignoreNulls, List createWindowAccumulator(lambdaProviders), hasRemoveInput); } - private WindowAccumulator createWindowAccumulator(List> lambdaProviders) + public WindowAccumulator createWindowAccumulator(List> lambdaProviders) { try { return constructor.newInstance(lambdaProviders); diff --git a/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java b/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java index 18af37b1adac..0c67e3fdd02c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java @@ -16,9 +16,8 @@ import io.trino.memory.context.AggregatedMemoryContext; import io.trino.memory.context.LocalMemoryContext; import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; -import io.trino.operator.aggregation.AggregationMetadata; import io.trino.operator.aggregation.WindowAccumulator; +import io.trino.operator.window.AggregationWindowFunctionSupplier; import io.trino.operator.window.MappedWindowIndex; import io.trino.operator.window.matcher.ArrayView; import io.trino.operator.window.pattern.SetEvaluator.SetEvaluatorSupplier; @@ -26,11 +25,9 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import java.lang.reflect.Constructor; import java.util.List; import java.util.function.Supplier; -import static io.trino.operator.aggregation.AccumulatorCompiler.generateWindowAccumulatorClass; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -193,8 +190,7 @@ public static class MatchAggregationInstantiator public MatchAggregationInstantiator( BoundSignature boundSignature, - AggregationMetadata aggregationMetadata, - FunctionNullability functionNullability, + AggregationWindowFunctionSupplier aggregationWindowFunctionSupplier, List argumentChannels, List> lambdaProviders, SetEvaluatorSupplier setEvaluatorSupplier) @@ -203,8 +199,7 @@ public MatchAggregationInstantiator( this.argumentChannels = requireNonNull(argumentChannels, "argumentChannels is null"); this.setEvaluatorSupplier = requireNonNull(setEvaluatorSupplier, "setEvaluatorSupplier is null"); - Constructor constructor = generateWindowAccumulatorClass(boundSignature, aggregationMetadata, functionNullability); - this.accumulatorFactory = () -> createWindowAccumulator(constructor, lambdaProviders); + this.accumulatorFactory = () -> aggregationWindowFunctionSupplier.createWindowAccumulator(lambdaProviders); } public MatchAggregation get(AggregatedMemoryContext memoryContextSupplier) @@ -212,15 +207,5 @@ public MatchAggregation get(AggregatedMemoryContext memoryContextSupplier) requireNonNull(memoryContextSupplier, "memoryContextSupplier is null"); return new MatchAggregation(boundSignature, accumulatorFactory, argumentChannels, setEvaluatorSupplier.get(), memoryContextSupplier); } - - private static WindowAccumulator createWindowAccumulator(Constructor constructor, List> lambdaProviders) - { - try { - return constructor.newInstance(lambdaProviders); - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 4f092b32fe4f..d557c8691f57 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -14,6 +14,7 @@ package io.trino.sql.planner; import com.google.common.base.VerifyException; +import com.google.common.cache.CacheBuilder; import com.google.common.collect.ContiguousSet; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableBiMap; @@ -31,6 +32,7 @@ import io.airlift.units.DataSize; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.collect.cache.NonEvictableCache; import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.DynamicFilterConfig; import io.trino.execution.ExplainAnalyzeContext; @@ -41,6 +43,8 @@ import io.trino.execution.buffer.OutputBuffer; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.index.IndexManager; +import io.trino.metadata.BoundSignature; +import io.trino.metadata.FunctionId; import io.trino.metadata.FunctionKind; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; @@ -256,6 +260,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -266,6 +271,7 @@ import java.util.stream.IntStream; import static com.google.common.base.Functions.forMap; +import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -291,6 +297,8 @@ import static io.trino.SystemSessionProperties.isExchangeCompressionEnabled; import static io.trino.SystemSessionProperties.isLateMaterializationEnabled; import static io.trino.SystemSessionProperties.isSpillEnabled; +import static io.trino.collect.cache.CacheUtils.uncheckedCacheGet; +import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.operator.DistinctLimitOperator.DistinctLimitOperatorFactory; import static io.trino.operator.HashArraySizeSupplier.incrementalLoadFactorHashArraySizeSupplier; import static io.trino.operator.PipelineExecutionStrategy.GROUPED_EXECUTION; @@ -354,6 +362,7 @@ import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.HOURS; import static java.util.stream.Collectors.partitioningBy; import static java.util.stream.IntStream.range; @@ -391,6 +400,13 @@ public class LocalExecutionPlanner private final ExchangeManagerRegistry exchangeManagerRegistry; private final PositionsAppenderFactory positionsAppenderFactory = new PositionsAppenderFactory(); + private final NonEvictableCache accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS)); + private final NonEvictableCache aggregationWindowFunctionSupplierCache = buildNonEvictableCache(CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS)); + @Inject public LocalExecutionPlanner( PlannerContext plannerContext, @@ -1167,11 +1183,13 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext private WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction) { if (resolvedFunction.getFunctionKind() == FunctionKind.AGGREGATE) { - AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction); - return new AggregationWindowFunctionSupplier( - resolvedFunction.getSignature(), - aggregationMetadata, - resolvedFunction.getFunctionNullability()); + return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), () -> { + AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction); + return new AggregationWindowFunctionSupplier( + resolvedFunction.getSignature(), + aggregationMetadata, + resolvedFunction.getFunctionNullability()); + }); } return plannerContext.getFunctionManager().getWindowFunctionImplementation(resolvedFunction); } @@ -1510,10 +1528,11 @@ else if (matchNumberSymbols.contains(pointer.getInputSymbol())) { boolean classifierInvolved = false; + ResolvedFunction resolvedFunction = pointer.getFunction(); AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(pointer.getFunction()); ImmutableList.Builder> builder = ImmutableList.builder(); - List signatureTypes = pointer.getFunction().getSignature().getArgumentTypes(); + List signatureTypes = resolvedFunction.getSignature().getArgumentTypes(); for (int i = 0; i < pointer.getArguments().size(); i++) { builder.add(new SimpleEntry<>(pointer.getArguments().get(i), signatureTypes.get(i))); } @@ -1526,7 +1545,7 @@ else if (matchNumberSymbols.contains(pointer.getInputSymbol())) { .map(LambdaExpression.class::cast) .collect(toImmutableList()); - List functionTypes = pointer.getFunction().getSignature().getArgumentTypes().stream() + List functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream() .filter(FunctionType.class::isInstance) .map(FunctionType.class::cast) .collect(toImmutableList()); @@ -1575,10 +1594,16 @@ else if (symbol.equals(matchNumberArgumentSymbol)) { } } + AggregationWindowFunctionSupplier aggregationWindowFunctionSupplier = uncheckedCacheGet( + aggregationWindowFunctionSupplierCache, + new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), + () -> new AggregationWindowFunctionSupplier( + resolvedFunction.getSignature(), + aggregationMetadata, + resolvedFunction.getFunctionNullability())); matchAggregations.add(new MatchAggregationInstantiator( - pointer.getFunction().getSignature(), - aggregationMetadata, - pointer.getFunction().getFunctionNullability(), + resolvedFunction.getSignature(), + aggregationWindowFunctionSupplier, valueChannels, lambdaProviders, new SetEvaluatorSupplier(pointer.getSetDescriptor(), mapping))); @@ -3546,11 +3571,15 @@ private AggregatorFactory buildAggregatorFactory( } } + ResolvedFunction resolvedFunction = aggregation.getResolvedFunction(); AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(aggregation.getResolvedFunction()); - AccumulatorFactory accumulatorFactory = generateAccumulatorFactory( - aggregation.getResolvedFunction().getSignature(), - aggregationMetadata, - aggregation.getResolvedFunction().getFunctionNullability()); + AccumulatorFactory accumulatorFactory = uncheckedCacheGet( + accumulatorFactoryCache, + new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), + () -> generateAccumulatorFactory( + resolvedFunction.getSignature(), + aggregationMetadata, + resolvedFunction.getFunctionNullability())); if (aggregation.isDistinct()) { accumulatorFactory = new DistinctAccumulatorFactory( @@ -3600,7 +3629,7 @@ private AggregatorFactory buildAggregatorFactory( .map(stateDescriptor -> stateDescriptor.getSerializer().getSerializedType()) .collect(toImmutableList()); Type intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); - Type finalType = aggregation.getResolvedFunction().getSignature().getReturnType(); + Type finalType = resolvedFunction.getSignature().getReturnType(); OptionalInt maskChannel = aggregation.getMask().stream() .mapToInt(value -> source.getLayout().get(value)) @@ -3610,7 +3639,7 @@ private AggregatorFactory buildAggregatorFactory( .filter(LambdaExpression.class::isInstance) .map(LambdaExpression.class::cast) .collect(toImmutableList()); - List functionTypes = aggregation.getResolvedFunction().getSignature().getArgumentTypes().stream() + List functionTypes = resolvedFunction.getSignature().getArgumentTypes().stream() .filter(FunctionType.class::isInstance) .map(FunctionType.class::cast) .collect(toImmutableList()); @@ -4169,4 +4198,55 @@ public boolean isClassifierInvolved() return classifierInvolved; } } + + private static class FunctionKey + { + private final FunctionId functionId; + private final BoundSignature boundSignature; + + public FunctionKey(FunctionId functionId, BoundSignature boundSignature) + { + this.functionId = requireNonNull(functionId, "functionId is null"); + this.boundSignature = requireNonNull(boundSignature, "boundSignature is null"); + } + + public FunctionId getFunctionId() + { + return functionId; + } + + public BoundSignature getBoundSignature() + { + return boundSignature; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FunctionKey that = (FunctionKey) o; + return functionId.equals(that.functionId) && + boundSignature.equals(that.boundSignature); + } + + @Override + public int hashCode() + { + return Objects.hash(functionId, boundSignature); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("functionId", functionId) + .add("boundSignature", boundSignature) + .toString(); + } + } }