Skip to content

Commit

Permalink
Ensure caches are not used unsafely
Browse files Browse the repository at this point in the history
Guava's `Cache` and `LoadingCache` have concurrency issues around
invalidation and ongoing loads.

Ensure that
- code uses `EvictableCache` or `EvictableLoadingCache` which fix the
  probem, or
- code uses safety wrappers, `NonEvictableCache`,
  `NonEvictableLoadingCache`, which fail when unsafe invalidation is
  called. Additionally, the interfaces have the unimplemented methods
  marked as `@Deprecated`, to signal the problem as early as possible.
  • Loading branch information
findepi committed Jan 21, 2022
1 parent 7c55910 commit 0747a37
Show file tree
Hide file tree
Showing 63 changed files with 813 additions and 261 deletions.
16 changes: 16 additions & 0 deletions .mvn/modernizer/violations.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@
<comment>Prefer Math.toIntExact(long)</comment>
</violation>

<violation>
<name>com/google/common/cache/CacheBuilder.build:()Lcom/google/common/cache/Cache;</name>
<version>1.8</version>
<comment>Guava Cache has concurrency issues around invalidation and ongoing loads. Use EvictableCache, EvictableLoadingCache, or SafeCaches to build caches.
See https://github.com/trinodb/trino/issues/10512 for more information and see https://github.com/trinodb/trino/issues/10512#issuecomment-1016221168
for why Caffeine does not solve the problem.</comment>
</violation>

<violation>
<name>com/google/common/cache/CacheBuilder.build:(Lcom/google/common/cache/CacheLoader;)Lcom/google/common/cache/LoadingCache;</name>
<version>1.8</version>
<comment>Guava LoadingCache has concurrency issues around invalidation and ongoing loads. Use EvictableCache, EvictableLoadingCache, or SafeCaches to build caches.
See https://github.com/trinodb/trino/issues/10512 for more information and see https://github.com/trinodb/trino/issues/10512#issuecomment-1016221168
for why Caffeine does not solve the problem.</comment>
</violation>

<violation>
<name>org/testng/Assert.assertEquals:(Ljava/lang/Iterable;Ljava/lang/Iterable;)V</name>
<version>1.8</version>
Expand Down
5 changes: 5 additions & 0 deletions client/trino-cli/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@
<artifactId>antlr4-runtime</artifactId>
</dependency>

<dependency>
<groupId>org.gaul</groupId>
<artifactId>modernizer-maven-annotations</artifactId>
</dependency>

<dependency>
<groupId>org.jline</groupId>
<artifactId>jline-reader</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.trino.client.QueryData;
import io.trino.client.StatementClient;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;
import org.jline.reader.Candidate;
import org.jline.reader.Completer;
import org.jline.reader.LineReader;
Expand Down Expand Up @@ -51,12 +52,21 @@ public TableNameCompleter(QueryRunner queryRunner)
{
this.queryRunner = requireNonNull(queryRunner, "queryRunner session was null!");

tableCache = CacheBuilder.newBuilder()
.refreshAfterWrite(RELOAD_TIME_MINUTES, TimeUnit.MINUTES)
.build(asyncReloading(CacheLoader.from(this::listTables), executor));
tableCache = buildUnsafeCache(
CacheBuilder.newBuilder()
.refreshAfterWrite(RELOAD_TIME_MINUTES, TimeUnit.MINUTES),
asyncReloading(CacheLoader.from(this::listTables), executor));

functionCache = CacheBuilder.newBuilder()
.build(CacheLoader.from(this::listFunctions));
functionCache = buildUnsafeCache(
CacheBuilder.newBuilder(),
CacheLoader.from(this::listFunctions));
}

// TODO extract safe caches implementations to a new module and use SafeCaches.buildNonEvictableCache hereAsyncCache
@SuppressModernizer
private static <K, V> LoadingCache<K, V> buildUnsafeCache(CacheBuilder<? super K, ? super V> cacheBuilder, CacheLoader<? super K, V> cacheLoader)
{
return cacheBuilder.build(cacheLoader);
}

private List<String> listTables(String schemaName)
Expand Down
5 changes: 5 additions & 0 deletions core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@
<version>8.4.1</version>
</dependency>

<dependency>
<groupId>org.gaul</groupId>
<artifactId>modernizer-maven-annotations</artifactId>
</dependency>

<dependency>
<groupId>org.jgrapht</groupId>
<artifactId>jgrapht-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
*/
package io.trino.execution;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.airlift.units.Duration;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.ErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.ErrorType;
Expand All @@ -29,6 +29,7 @@
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.execution.FailureInjector.InjectedFailureType.TASK_FAILURE;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.ErrorType.EXTERNAL;
import static io.trino.spi.ErrorType.INSUFFICIENT_RESOURCES;
import static io.trino.spi.ErrorType.INTERNAL_ERROR;
Expand All @@ -40,7 +41,7 @@ public class FailureInjector
{
public static final String FAILURE_INJECTION_MESSAGE = "This error is injected by the failure injection service";

private final Cache<Key, InjectedFailure> failures;
private final NonEvictableCache<Key, InjectedFailure> failures;
private final Duration requestTimeout;

@Inject
Expand All @@ -53,9 +54,8 @@ public FailureInjector(FailureInjectionConfig config)

public FailureInjector(Duration expirationPeriod, Duration requestTimeout)
{
failures = CacheBuilder.newBuilder()
.expireAfterWrite(expirationPeriod.toMillis(), MILLISECONDS)
.build();
failures = buildNonEvictableCache(CacheBuilder.newBuilder()
.expireAfterWrite(expirationPeriod.toMillis(), MILLISECONDS));
this.requestTimeout = requireNonNull(requestTimeout, "requestTimeout is null");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.ThreadPoolExecutorMBean;
Expand All @@ -41,6 +40,7 @@
import io.trino.memory.MemoryPoolAssignmentsRequest;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.plugin.base.cache.NonEvictableLoadingCache;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.spi.VersionEmbedder;
Expand Down Expand Up @@ -81,6 +81,7 @@
import static io.trino.execution.SqlTask.createSqlTask;
import static io.trino.memory.LocalMemoryManager.GENERAL_POOL;
import static io.trino.memory.LocalMemoryManager.RESERVED_POOL;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.StandardErrorCode.ABANDONED_TASK;
import static io.trino.spi.StandardErrorCode.SERVER_SHUTTING_DOWN;
import static java.lang.Math.min;
Expand All @@ -105,8 +106,8 @@ public class SqlTaskManager
private final Duration clientTimeout;

private final LocalMemoryManager localMemoryManager;
private final LoadingCache<QueryId, QueryContext> queryContexts;
private final LoadingCache<TaskId, SqlTask> tasks;
private final NonEvictableLoadingCache<QueryId, QueryContext> queryContexts;
private final NonEvictableLoadingCache<TaskId, SqlTask> tasks;

private final SqlTaskIoStats cachedStats = new SqlTaskIoStats();
private final SqlTaskIoStats finishedTaskStats = new SqlTaskIoStats();
Expand Down Expand Up @@ -165,10 +166,10 @@ public SqlTaskManager(
queryMaxMemoryPerNode = maxQueryMemoryPerNode.toBytes();
queryMaxTotalMemoryPerNode = maxQueryTotalMemoryPerNode.toBytes();

queryContexts = CacheBuilder.newBuilder().weakValues().build(CacheLoader.from(
queryContexts = buildNonEvictableCache(CacheBuilder.newBuilder().weakValues(), CacheLoader.from(
queryId -> createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQueryTotalMemoryPerNode, queryMaxMemoryPerTask, maxQuerySpillPerNode)));

tasks = CacheBuilder.newBuilder().build(CacheLoader.from(
tasks = buildNonEvictableCache(CacheBuilder.newBuilder(), CacheLoader.from(
taskId -> createSqlTask(
taskId,
locationFactory.createLocalTaskLocation(taskId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.execution.scheduler;

import com.google.common.base.Suppliers;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
Expand All @@ -27,6 +26,7 @@
import io.trino.execution.NodeTaskMap;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.HostAddress;
import io.trino.spi.SplitWeight;

Expand All @@ -46,16 +46,17 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask;
import static io.trino.metadata.NodeState.ACTIVE;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static java.util.Objects.requireNonNull;

public class TopologyAwareNodeSelectorFactory
implements NodeSelectorFactory
{
private static final Logger LOG = Logger.get(TopologyAwareNodeSelectorFactory.class);

private final Cache<InternalNode, Object> inaccessibleNodeLogCache = CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS)
.build();
private final NonEvictableCache<InternalNode, Object> inaccessibleNodeLogCache = buildNonEvictableCache(
CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS));

private final NetworkTopology networkTopology;
private final InternalNodeManager nodeManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Suppliers;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableSetMultimap;
import io.airlift.log.Logger;
Expand All @@ -25,6 +24,7 @@
import io.trino.execution.NodeTaskMap;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.HostAddress;
import io.trino.spi.SplitWeight;

Expand All @@ -42,6 +42,7 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask;
import static io.trino.metadata.NodeState.ACTIVE;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
Expand All @@ -51,9 +52,9 @@ public class UniformNodeSelectorFactory
{
private static final Logger LOG = Logger.get(UniformNodeSelectorFactory.class);

private final Cache<InternalNode, Object> inaccessibleNodeLogCache = CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS)
.build();
private final NonEvictableCache<InternalNode, Object> inaccessibleNodeLogCache = buildNonEvictableCache(
CacheBuilder.newBuilder()
.expireAfterWrite(30, TimeUnit.SECONDS));

private final InternalNodeManager nodeManager;
private final int minCandidates;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
import com.fasterxml.jackson.databind.ser.BeanSerializerFactory;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.plugin.base.cache.SafeCaches;

import java.io.IOException;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -92,7 +93,7 @@ private static class InternalTypeSerializer<T>
extends StdSerializer<T>
{
private final TypeSerializer typeSerializer;
private final Cache<Class<?>, JsonSerializer<T>> serializerCache = CacheBuilder.newBuilder().build();
private final NonEvictableCache<Class<?>, JsonSerializer<T>> serializerCache = SafeCaches.buildNonEvictableCache(CacheBuilder.newBuilder());

public InternalTypeSerializer(Class<T> baseClass, TypeIdResolver typeIdResolver)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.metadata;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
Expand Down Expand Up @@ -237,6 +236,7 @@
import io.trino.operator.window.RowNumberFunction;
import io.trino.operator.window.SqlWindowFunction;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.spi.function.InvocationConvention;
Expand Down Expand Up @@ -341,6 +341,7 @@
import static io.trino.operator.scalar.TryCastFunction.TRY_CAST;
import static io.trino.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
import static io.trino.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.type.DecimalCasts.BIGINT_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.BOOLEAN_TO_DECIMAL_CAST;
import static io.trino.type.DecimalCasts.DECIMAL_TO_BIGINT_CAST;
Expand Down Expand Up @@ -380,9 +381,9 @@
@ThreadSafe
public class FunctionRegistry
{
private final Cache<FunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final Cache<FunctionKey, AggregationMetadata> specializedAggregationCache;
private final Cache<FunctionKey, WindowFunctionSupplier> specializedWindowCache;
private final NonEvictableCache<FunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final NonEvictableCache<FunctionKey, AggregationMetadata> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, WindowFunctionSupplier> specializedWindowCache;
private volatile FunctionMap functions = new FunctionMap();

public FunctionRegistry(
Expand All @@ -398,20 +399,17 @@ public FunctionRegistry(
// with generated classes and/or dynamically-created MethodHandles.
// This might also mitigate problems like deoptimization storm or unintended interpreted execution.

specializedScalarCache = CacheBuilder.newBuilder()
specializedScalarCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, HOURS)
.build();
.expireAfterWrite(1, HOURS));

specializedAggregationCache = CacheBuilder.newBuilder()
specializedAggregationCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, HOURS)
.build();
.expireAfterWrite(1, HOURS));

specializedWindowCache = CacheBuilder.newBuilder()
specializedWindowCache = buildNonEvictableCache(CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, HOURS)
.build();
.expireAfterWrite(1, HOURS));

FunctionListBuilder builder = new FunctionListBuilder()
.window(RowNumberFunction.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.metadata;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -32,6 +31,7 @@
import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.plugin.base.cache.NonEvictableCache;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -147,6 +147,7 @@
import static io.trino.metadata.RedirectionAwareTableHandle.withRedirectionTo;
import static io.trino.metadata.Signature.mangleOperatorName;
import static io.trino.metadata.SignatureBinder.applyBoundVariables;
import static io.trino.plugin.base.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
Expand Down Expand Up @@ -180,8 +181,8 @@ public final class MetadataManager

private final ResolvedFunctionDecoder functionDecoder;

private final Cache<OperatorCacheKey, ResolvedFunction> operatorCache;
private final Cache<CoercionCacheKey, ResolvedFunction> coercionCache;
private final NonEvictableCache<OperatorCacheKey, ResolvedFunction> operatorCache;
private final NonEvictableCache<CoercionCacheKey, ResolvedFunction> coercionCache;

@Inject
public MetadataManager(
Expand All @@ -204,13 +205,8 @@ public MetadataManager(

functionDecoder = new ResolvedFunctionDecoder(typeManager::getType);

operatorCache = CacheBuilder.newBuilder()
.maximumSize(1000)
.build();

coercionCache = CacheBuilder.newBuilder()
.maximumSize(1000)
.build();
operatorCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000));
coercionCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000));
}

public static MetadataManager createTestMetadataManager()
Expand Down
Loading

0 comments on commit 0747a37

Please sign in to comment.