diff --git a/.gitignore b/.gitignore index f1d9472f81de7..7d4f797920d9f 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ benchmark_outputs *.class .checkstyle .mvn/timing.properties +.editorconfig diff --git a/.travis.yml b/.travis.yml index 9b48214fd8ece..96bacd1d25bae 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,11 +16,17 @@ env: - TEST_SPECIFIC_MODULES=presto-cassandra - TEST_SPECIFIC_MODULES=presto-hive - TEST_OTHER_MODULES=!presto-tests,!presto-raptor,!presto-accumulo,!presto-cassandra,!presto-hive,!presto-docs,!presto-server,!presto-server-rpm - - PRODUCT_TESTS=true + - PRODUCT_TESTS_BASIC_ENVIRONMENT=true + - PRODUCT_TESTS_SPECIFIC_ENVIRONMENT=true - HIVE_TESTS=true sudo: required dist: trusty +group: deprecated-2017Q2 +addons: + apt: + packages: + - oracle-java8-installer cache: directories: @@ -40,7 +46,7 @@ install: ./mvnw install $MAVEN_FAST_INSTALL -pl '!presto-docs,!presto-server,!presto-server-rpm' fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_BASIC_ENVIRONMENT || -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then ./mvnw install $MAVEN_FAST_INSTALL -pl '!presto-docs,!presto-server-rpm' fi - | @@ -62,28 +68,33 @@ script: ./mvnw test $MAVEN_SKIP_CHECKS_AND_DOCS -B -pl $TEST_OTHER_MODULES fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_BASIC_ENVIRONMENT ]]; then presto-product-tests/bin/run_on_docker.sh \ multinode -x quarantine,big_query,storage_formats,profile_specific_tests,tpcds fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then presto-product-tests/bin/run_on_docker.sh \ singlenode-kerberos-hdfs-impersonation -g storage_formats,cli,hdfs_impersonation,authorization fi - | - if [[ -v PRODUCT_TESTS ]]; then + if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then presto-product-tests/bin/run_on_docker.sh \ - singlenode-ldap -g ldap_cli + singlenode-ldap -g ldap -x simba_jdbc fi # SQL server image sporadically hangs during the startup # TODO: Uncomment it once issue is fixed # https://github.com/Microsoft/mssql-docker/issues/76 # - | -# if [[ -v PRODUCT_TESTS ]]; then +# if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then # presto-product-tests/bin/run_on_docker.sh \ # singlenode-sqlserver -g sqlserver # fi + - | + if [[ -v PRODUCT_TESTS_SPECIFIC_ENVIRONMENT ]]; then + presto-product-tests/bin/run_on_docker.sh \ + multinode-tls -g smoke,cli,group-by,join,tls + fi - | if [[ -v HIVE_TESTS ]]; then presto-hive-hadoop2/bin/run_on_docker.sh diff --git a/README.md b/README.md index 7df5380c68c48..acfe40570b3da 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,13 @@ See the [User Manual](https://prestodb.io/docs/current/) for deployment instruct Presto is a standard Maven project. Simply run the following command from the project root directory: - mvn clean install + ./mvnw clean install On the first build, Maven will download all the dependencies from the internet and cache them in the local repository (`~/.m2/repository`), which can take a considerable amount of time. Subsequent builds will be faster. Presto has a comprehensive set of unit tests that can take several minutes to run. You can disable the tests when building: - mvn clean install -DskipTests + ./mvnw clean install -DskipTests ## Running Presto in your IDE diff --git a/pom.xml b/pom.xml index c7eb600de48c7..e733d6a8c3d60 100644 --- a/pom.xml +++ b/pom.xml @@ -5,12 +5,12 @@ io.airlift airbase - 62 + 64 com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT pom presto-root @@ -47,15 +47,15 @@ 3.3.9 4.6 - 0.145 + 0.148 ${dep.airlift.version} 0.29 1.11.30 - 1.30 + 3.8.1 + 1.31 6.10 - - true - None + 0.15.1 + 0.15.2 Asia/Katmandu @@ -81,10 +81,7 @@ presto-orc presto-rcfile presto-hive - presto-hive-hadoop1 presto-hive-hadoop2 - presto-hive-cdh4 - presto-hive-cdh5 presto-teradata-functions presto-example-http presto-local-file @@ -114,6 +111,9 @@ presto-plugin-toolkit presto-resource-group-managers presto-benchto-benchmarks + presto-thrift-connector-api + presto-thrift-testing-server + presto-thrift-connector @@ -180,13 +180,6 @@ ${project.version} - - com.facebook.presto - presto-hive-cdh4 - ${project.version} - zip - - com.facebook.presto presto-example-http @@ -337,32 +330,46 @@ com.facebook.presto.hadoop - hadoop-apache1 - 0.4 + hadoop-apache2 + 2.7.3-1 - com.facebook.presto.hadoop - hadoop-apache2 - 0.10 + com.facebook.presto.hive + hive-apache + 1.2.0-2 - com.facebook.presto.hadoop - hadoop-cdh4 - 0.10 + com.facebook.presto.orc + orc-protobuf + 3 - com.facebook.presto.hive - hive-apache - 1.2.0-1 + com.facebook.presto + presto-thrift-connector-api + ${project.version} - com.facebook.presto.orc - orc-protobuf - 2 + com.facebook.presto + presto-thrift-connector-api + ${project.version} + test-jar + + + + com.facebook.presto + presto-thrift-testing-server + ${project.version} + + + + com.facebook.presto + presto-thrift-connector + ${project.version} + zip @@ -374,7 +381,7 @@ io.airlift aircompressor - 0.5 + 0.7 @@ -455,6 +462,12 @@ ${dep.airlift.version} + + io.airlift + jaxrs-testing + ${dep.airlift.version} + + io.airlift jmx @@ -566,7 +579,7 @@ mysql mysql-connector-java - 5.1.35 + 5.1.41 @@ -619,6 +632,54 @@ 2.78 + + com.squareup.okhttp3 + okhttp + ${dep.okhttp.version} + + + + com.squareup.okhttp3 + mockwebserver + ${dep.okhttp.version} + + + + com.facebook.swift + swift-annotations + ${dep.swift.version} + + + + com.facebook.swift + swift-codec + ${dep.swift.version} + + + + com.facebook.swift + swift-service + ${dep.swift.version} + + + + com.facebook.swift + swift-javadoc + ${dep.swift.version} + + + + com.facebook.nifty + nifty-core + ${dep.nifty.version} + + + + com.facebook.nifty + nifty-client + ${dep.nifty.version} + + org.apache.thrift libthrift @@ -897,7 +958,7 @@ org.codehaus.mojo exec-maven-plugin - 1.2.1 + 1.6.0 @@ -919,6 +980,27 @@ + + com.ning.maven.plugins + maven-dependency-versions-check-plugin + + + + com.google.inject + guice + 4.0-beta5 + 4.0 + + + com.google.inject.extensions + guice-multibindings + 4.0-beta5 + 4.0 + + + + + @@ -1094,26 +1176,6 @@ - - - cli - - - - org.codehaus.mojo - exec-maven-plugin - - ${cli.skip-execute} - ${java.home}/bin/java - ${cli.main-class} - - --debug - - - - - - eclipse-compiler diff --git a/presto-accumulo/pom.xml b/presto-accumulo/pom.xml index e3854c50ab730..4645eb6b3bdb9 100644 --- a/presto-accumulo/pom.xml +++ b/presto-accumulo/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-accumulo @@ -317,5 +317,11 @@ testng test + + + javax.annotation + javax.annotation-api + test + diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java index da91ef8c396b6..d3c9cff3683fa 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java @@ -101,17 +101,17 @@ public AccumuloClient( Connector connector, AccumuloConfig config, ZooKeeperMetadataManager metaManager, - AccumuloTableManager tableManager) + AccumuloTableManager tableManager, + IndexLookup indexLookup) throws AccumuloException, AccumuloSecurityException { this.connector = requireNonNull(connector, "connector is null"); this.username = requireNonNull(config, "config is null").getUsername(); this.metaManager = requireNonNull(metaManager, "metaManager is null"); this.tableManager = requireNonNull(tableManager, "tableManager is null"); - this.auths = connector.securityOperations().getUserAuthorizations(username); + this.indexLookup = requireNonNull(indexLookup, "indexLookup is null"); - // Create the index lookup utility - this.indexLookup = new IndexLookup(connector, config, this.auths); + this.auths = connector.securityOperations().getUserAuthorizations(username); } public AccumuloTable createTable(ConnectorTableMetadata meta) @@ -440,9 +440,6 @@ public void dropTable(AccumuloTable table) { SchemaTableName tableName = new SchemaTableName(table.getSchema(), table.getTable()); - // Drop cardinality cache from index lookup - indexLookup.dropCache(tableName.getSchemaName(), tableName.getTableName()); - // Remove the table metadata from Presto if (metaManager.getTable(tableName) != null) { metaManager.deleteTableMetadata(tableName); diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java index 60993e8833a46..32529b4023532 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java @@ -16,6 +16,8 @@ import com.facebook.presto.accumulo.conf.AccumuloConfig; import com.facebook.presto.accumulo.conf.AccumuloSessionProperties; import com.facebook.presto.accumulo.conf.AccumuloTableProperties; +import com.facebook.presto.accumulo.index.ColumnCardinalityCache; +import com.facebook.presto.accumulo.index.IndexLookup; import com.facebook.presto.accumulo.io.AccumuloPageSinkProvider; import com.facebook.presto.accumulo.io.AccumuloRecordSetProvider; import com.facebook.presto.accumulo.metadata.AccumuloTable; @@ -96,6 +98,8 @@ public void configure(Binder binder) binder.bind(AccumuloTableProperties.class).in(Scopes.SINGLETON); binder.bind(ZooKeeperMetadataManager.class).in(Scopes.SINGLETON); binder.bind(AccumuloTableManager.class).in(Scopes.SINGLETON); + binder.bind(IndexLookup.class).in(Scopes.SINGLETON); + binder.bind(ColumnCardinalityCache.class).in(Scopes.SINGLETON); binder.bind(Connector.class).toProvider(ConnectorProvider.class); configBinder(binder).bindConfig(AccumuloConfig.class); diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java index 4330f3619cbb4..d3e09b15a27dc 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java @@ -15,6 +15,7 @@ import io.airlift.configuration.Config; import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.ConfigSecuritySensitive; import io.airlift.units.Duration; import javax.validation.constraints.Min; @@ -92,6 +93,7 @@ public String getPassword() } @Config(PASSWORD) + @ConfigSecuritySensitive @ConfigDescription("Sets the password for the configured user") public AccumuloConfig setPassword(String password) { diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java index 255204a21e61e..d266cbebb2ffc 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; +import io.airlift.units.Duration; import javax.inject.Inject; @@ -25,6 +26,7 @@ import static com.facebook.presto.spi.session.PropertyMetadata.doubleSessionProperty; import static com.facebook.presto.spi.session.PropertyMetadata.integerSessionProperty; import static com.facebook.presto.spi.session.PropertyMetadata.stringSessionProperty; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; /** * Class contains all session-based properties for the Accumulo connector. @@ -44,6 +46,8 @@ public final class AccumuloSessionProperties private static final String INDEX_LOWEST_CARDINALITY_THRESHOLD = "index_lowest_cardinality_threshold"; private static final String INDEX_METRICS_ENABLED = "index_metrics_enabled"; private static final String SCAN_USERNAME = "scan_username"; + private static final String INDEX_SHORT_CIRCUIT_CARDINALITY_FETCH = "index_short_circuit_cardinality_fetch"; + private static final String INDEX_CARDINALITY_CACHE_POLLING_DURATION = "index_cardinality_cache_polling_duration"; private final List> sessionProperties; @@ -94,7 +98,22 @@ public AccumuloSessionProperties() true, false); - sessionProperties = ImmutableList.of(s1, s2, s3, s4, s5, s6, s7, s8); + PropertyMetadata s9 = booleanSessionProperty( + INDEX_SHORT_CIRCUIT_CARDINALITY_FETCH, + "Short circuit the retrieval of index metrics once any column is less than the lowest cardinality threshold. Default true", + true, + false); + + PropertyMetadata s10 = new PropertyMetadata<>( + INDEX_CARDINALITY_CACHE_POLLING_DURATION, + "Sets the cardinality cache polling duration for short circuit retrieval of index metrics. Default 10ms", + VARCHAR, String.class, + "10ms", + false, + duration -> Duration.valueOf(duration.toString()).toString(), + object -> object); + + sessionProperties = ImmutableList.of(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10); } public List> getSessionProperties() @@ -132,6 +151,11 @@ public static double getIndexSmallCardThreshold(ConnectorSession session) return session.getProperty(INDEX_LOWEST_CARDINALITY_THRESHOLD, Double.class); } + public static Duration getIndexCardinalityCachePollingDuration(ConnectorSession session) + { + return Duration.valueOf(session.getProperty(INDEX_CARDINALITY_CACHE_POLLING_DURATION, String.class)); + } + public static boolean isIndexMetricsEnabled(ConnectorSession session) { return session.getProperty(INDEX_METRICS_ENABLED, Boolean.class); @@ -141,4 +165,9 @@ public static String getScanUsername(ConnectorSession session) { return session.getProperty(SCAN_USERNAME, String.class); } + + public static boolean isIndexShortCircuitEnabled(ConnectorSession session) + { + return session.getProperty(INDEX_SHORT_CIRCUIT_CARDINALITY_FETCH, Boolean.class); + } } diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java index f4afef21a5dfa..8474240c7ffc3 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java @@ -14,39 +14,63 @@ package com.facebook.presto.accumulo.index; import com.facebook.presto.accumulo.conf.AccumuloConfig; -import com.facebook.presto.accumulo.metadata.AccumuloTable; import com.facebook.presto.accumulo.model.AccumuloColumnConstraint; +import com.facebook.presto.spi.PrestoException; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Iterables; +import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; -import com.google.common.collect.TreeMultimap; +import com.google.common.collect.MultimapBuilder; +import io.airlift.concurrent.BoundedExecutor; import io.airlift.log.Logger; import io.airlift.units.Duration; import org.apache.accumulo.core.client.BatchScanner; import org.apache.accumulo.core.client.Connector; +import org.apache.accumulo.core.client.Scanner; import org.apache.accumulo.core.client.TableNotFoundException; import org.apache.accumulo.core.data.Key; import org.apache.accumulo.core.data.PartialKey; import org.apache.accumulo.core.data.Range; import org.apache.accumulo.core.data.Value; import org.apache.accumulo.core.security.Authorizations; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.io.Text; -import javax.annotation.concurrent.GuardedBy; +import javax.annotation.Nonnull; +import javax.annotation.PreDestroy; +import javax.inject.Inject; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.function.Function; import java.util.stream.Collectors; +import static com.facebook.presto.accumulo.AccumuloErrorCode.UNEXPECTED_ACCUMULO_ERROR; +import static com.facebook.presto.accumulo.index.Indexer.CARDINALITY_CQ_AS_TEXT; +import static com.facebook.presto.accumulo.index.Indexer.getIndexColumnFamily; +import static com.facebook.presto.accumulo.index.Indexer.getMetricsTableName; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.Streams.stream; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static java.lang.Long.parseLong; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.TimeUnit.MILLISECONDS; /** * This class is an indexing utility to cache the cardinality of a column value for every table. @@ -57,40 +81,34 @@ */ public class ColumnCardinalityCache { - private final Authorizations auths; private static final Logger LOG = Logger.get(ColumnCardinalityCache.class); private final Connector connector; - private final int size; - private final Duration expireDuration; + private final ExecutorService coreExecutor; + private final BoundedExecutor executorService; + private final LoadingCache cache; - @GuardedBy("this") - private final Map tableToCache = new HashMap<>(); - - public ColumnCardinalityCache( - Connector connector, - AccumuloConfig config, - Authorizations auths) + @Inject + public ColumnCardinalityCache(Connector connector, AccumuloConfig config) { this.connector = requireNonNull(connector, "connector is null"); - this.size = requireNonNull(config, "config is null").getCardinalityCacheSize(); - this.expireDuration = config.getCardinalityCacheExpiration(); - this.auths = requireNonNull(auths, "auths is null"); + int size = requireNonNull(config, "config is null").getCardinalityCacheSize(); + Duration expireDuration = config.getCardinalityCacheExpiration(); + + // Create a bounded executor with a pool size at 4x number of processors + this.coreExecutor = newCachedThreadPool(daemonThreadsNamed("cardinality-lookup-%s")); + this.executorService = new BoundedExecutor(coreExecutor, 4 * Runtime.getRuntime().availableProcessors()); + + LOG.debug("Created new cache size %d expiry %s", size, expireDuration); + cache = CacheBuilder.newBuilder() + .maximumSize(size) + .expireAfterWrite(expireDuration.toMillis(), MILLISECONDS) + .build(new CardinalityCacheLoader()); } - /** - * Deletes any cache for the given table, no-op of table does not exist in the cache - * - * @param schema Schema name - * @param table Table name - */ - public synchronized void deleteCache(String schema, String table) + @PreDestroy + public void shutdown() { - LOG.debug("Deleting cache for %s.%s", schema, table); - if (tableToCache.containsKey(table)) { - // clear the cache and remove it - getTableCache(schema, table).clear(); - tableToCache.remove(table); - } + coreExecutor.shutdownNow(); } /** @@ -99,161 +117,210 @@ public synchronized void deleteCache(String schema, String table) * * @param schema Schema name * @param table Table name + * @param auths Scan authorizations * @param idxConstraintRangePairs Mapping of all ranges for a given constraint + * @param earlyReturnThreshold Smallest acceptable cardinality to return early while other tasks complete + * @param pollingDuration Duration for polling the cardinality completion service * @return An immutable multimap of cardinality to column constraint, sorted by cardinality from smallest to largest * @throws TableNotFoundException If the metrics table does not exist * @throws ExecutionException If another error occurs; I really don't even know anymore. */ - public Multimap getCardinalities(String schema, String table, Multimap idxConstraintRangePairs) + public Multimap getCardinalities(String schema, String table, Authorizations auths, Multimap idxConstraintRangePairs, long earlyReturnThreshold, Duration pollingDuration) throws ExecutionException, TableNotFoundException { - // Create a multi map sorted by cardinality, sort columns by name - TreeMultimap cardinalityToConstraints = TreeMultimap.create( - Long::compare, - (AccumuloColumnConstraint o1, AccumuloColumnConstraint o2) -> o1.getName().compareTo(o2.getName())); - - for (Entry> entry : idxConstraintRangePairs.asMap().entrySet()) { - long card = getColumnCardinality(schema, table, entry.getKey(), entry.getValue()); - LOG.debug("Cardinality for column %s is %s", entry.getKey().getName(), card); - cardinalityToConstraints.put(card, entry.getKey()); + // Submit tasks to the executor to fetch column cardinality, adding it to the Guava cache if necessary + CompletionService> executor = new ExecutorCompletionService<>(executorService); + idxConstraintRangePairs.asMap().forEach((key, value) -> executor.submit(() -> { + long cardinality = getColumnCardinality(schema, table, auths, key.getFamily(), key.getQualifier(), value); + LOG.debug("Cardinality for column %s is %s", key.getName(), cardinality); + return Pair.of(cardinality, key); + } + )); + + // Create a multi map sorted by cardinality + ListMultimap cardinalityToConstraints = MultimapBuilder.treeKeys().arrayListValues().build(); + try { + boolean earlyReturn = false; + int numTasks = idxConstraintRangePairs.asMap().entrySet().size(); + do { + // Sleep for the polling duration to allow concurrent tasks to run for this time + Thread.sleep(pollingDuration.toMillis()); + + // Poll each task, retrieving the result if it is done + for (int i = 0; i < numTasks; ++i) { + Future> futureCardinality = executor.poll(); + if (futureCardinality != null && futureCardinality.isDone()) { + Pair columnCardinality = futureCardinality.get(); + cardinalityToConstraints.put(columnCardinality.getLeft(), columnCardinality.getRight()); + } + } + + // If the smallest cardinality is present and below the threshold, set the earlyReturn flag + Optional> smallestCardinality = cardinalityToConstraints.entries().stream().findFirst(); + if (smallestCardinality.isPresent()) { + if (smallestCardinality.get().getKey() <= earlyReturnThreshold) { + LOG.info("Cardinality %s, is below threshold. Returning early while other tasks finish", smallestCardinality); + earlyReturn = true; + } + } + } + while (!earlyReturn && cardinalityToConstraints.entries().size() < numTasks); + } + catch (ExecutionException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new PrestoException(UNEXPECTED_ACCUMULO_ERROR, "Exception when getting cardinality", e); } + // Create a copy of the cardinalities return ImmutableMultimap.copyOf(cardinalityToConstraints); } /** - * Gets the cardinality for the given column constraint with the given Ranges. - * Ranges can be exact values or a range of values. + * Gets the column cardinality for all of the given range values. May reach out to the + * metrics table in Accumulo to retrieve new cache elements. * - * @param schema Schema name + * @param schema Table schema * @param table Table name - * @param columnConstraint Mapping of all ranges for a given constraint - * @param indexRanges Ranges for each exact or ranged value of the column constraint - * @return The cardinality for the column - * @throws TableNotFoundException If the metrics table does not exist - * @throws ExecutionException If another error occurs; I really don't even know anymore. + * @param auths Scan authorizations + * @param family Accumulo column family + * @param qualifier Accumulo column qualifier + * @param colValues All range values to summarize for the cardinality + * @return The cardinality of the column */ - private long getColumnCardinality(String schema, String table, AccumuloColumnConstraint columnConstraint, Collection indexRanges) - throws ExecutionException, TableNotFoundException + public long getColumnCardinality(String schema, String table, Authorizations auths, String family, String qualifier, Collection colValues) + throws ExecutionException { - return getTableCache(schema, table) - .getColumnCardinality( - columnConstraint.getName(), - columnConstraint.getFamily(), - columnConstraint.getQualifier(), - indexRanges); + LOG.debug("Getting cardinality for %s:%s", family, qualifier); + + // Collect all exact Accumulo Ranges, i.e. single value entries vs. a full scan + Collection exactRanges = colValues.stream() + .filter(ColumnCardinalityCache::isExact) + .map(range -> new CacheKey(schema, table, family, qualifier, range, auths)) + .collect(Collectors.toList()); + + LOG.debug("Column values contain %s exact ranges of %s", exactRanges.size(), colValues.size()); + + // Sum the cardinalities for the exact-value Ranges + // This is where the reach-out to Accumulo occurs for all Ranges that have not + // previously been fetched + long sum = cache.getAll(exactRanges).values().stream().mapToLong(Long::longValue).sum(); + + // If these collection sizes are not equal, + // then there is at least one non-exact range + if (exactRanges.size() != colValues.size()) { + // for each range in the column value + for (Range range : colValues) { + // if this range is not exact + if (!isExact(range)) { + // Then get the value for this range using the single-value cache lookup + sum += cache.get(new CacheKey(schema, table, family, qualifier, range, auths)); + } + } + } + + return sum; } - /** - * Gets the {@link TableColumnCache} for the given table, creating a new one if necessary. - * - * @param schema Schema name - * @param table Table name - * @return An existing or new TableColumnCache - */ - private synchronized TableColumnCache getTableCache(String schema, String table) + private static boolean isExact(Range range) { - String fullName = AccumuloTable.getFullTableName(schema, table); - TableColumnCache cache = tableToCache.get(fullName); - if (cache == null) { - LOG.debug("Creating new TableColumnCache for %s.%s %s", schema, table, this); - cache = new TableColumnCache(schema, table); - tableToCache.put(fullName, cache); - } - return cache; + return !range.isInfiniteStartKey() && !range.isInfiniteStopKey() && + range.getStartKey().followingKey(PartialKey.ROW).equals(range.getEndKey()); } /** - * Internal class for holding the mapping of column names to the LoadingCache + * Complex key for the CacheLoader */ - private class TableColumnCache + private static class CacheKey { - private final Map> columnToCache = new HashMap<>(); private final String schema; private final String table; + private final String family; + private final String qualifier; + private final Range range; + private final Authorizations auths; - public TableColumnCache(String schema, - String table) + public CacheKey( + String schema, + String table, + String family, + String qualifier, + Range range, + Authorizations auths) { - this.schema = schema; - this.table = table; + this.schema = requireNonNull(schema, "schema is null"); + this.table = requireNonNull(table, "table is null"); + this.family = requireNonNull(family, "family is null"); + this.qualifier = requireNonNull(qualifier, "qualifier is null"); + this.range = requireNonNull(range, "range is null"); + this.auths = requireNonNull(auths, "auths is null"); } - /** - * Clears and removes all caches as if the object had been first created - */ - public void clear() + public String getSchema() { - columnToCache.values().forEach(LoadingCache::invalidateAll); - columnToCache.clear(); + return schema; } - /** - * Gets the column cardinality for all of the given range values. - * May reach out to the metrics table in Accumulo to retrieve new cache elements. - * - * @param column Presto column name - * @param family Accumulo column family - * @param qualifier Accumulo column qualifier - * @param colValues All range values to summarize for the cardinality - * @return The cardinality of the column - */ - public long getColumnCardinality(String column, String family, String qualifier, Collection colValues) - throws ExecutionException, TableNotFoundException + public String getTable() { - // Get the column cache for this column, creating a new one if necessary - LoadingCache cache = columnToCache.get(column); - if (cache == null) { - cache = newCache(schema, table, family, qualifier); - columnToCache.put(column, cache); - } + return table; + } - // Collect all exact Accumulo Ranges, i.e. single value entries vs. a full scan - Collection exactRanges = colValues.stream().filter(this::isExact).collect(Collectors.toList()); - LOG.debug("Column values contain %s exact ranges of %s", exactRanges.size(), colValues.size()); + public String getFamily() + { + return family; + } - // Sum the cardinalities for the exact-value Ranges - // This is where the reach-out to Accumulo occurs for all Ranges that have not previously been fetched - long sum = 0; - for (Long value : cache.getAll(exactRanges).values()) { - sum += value; - } + public String getQualifier() + { + return qualifier; + } - // If these collection sizes are not equal, then there is at least one non-exact range - if (exactRanges.size() != colValues.size()) { - // for each range in the column value - for (Range range : colValues) { - // if this range is not exact - if (!isExact(range)) { - // Then get the value for this range using the single-value cache lookup - long value = cache.get(range); - - // add our value to the cache and our sum - cache.put(range, value); - sum += value; - } - } - } + public Range getRange() + { + return range; + } - LOG.debug("Cache stats : size=%s, %s", cache.size(), cache.stats()); - return sum; + public Authorizations getAuths() + { + return auths; } - private boolean isExact(Range range) + @Override + public int hashCode() { - return !range.isInfiniteStartKey() - && !range.isInfiniteStopKey() - && range.getStartKey().followingKey(PartialKey.ROW).equals(range.getEndKey()); + return Objects.hash(schema, table, family, qualifier, range); } - private LoadingCache newCache(String schema, String table, String family, String qualifier) + @Override + public boolean equals(Object obj) { - LOG.debug("Created new cache for %s.%s, column %s:%s, size %s expiry %s", schema, table, family, qualifier, size, expireDuration); - return CacheBuilder - .newBuilder() - .maximumSize(size) - .expireAfterWrite(expireDuration.toMillis(), TimeUnit.MILLISECONDS) - .build(new CardinalityCacheLoader(schema, table, family, qualifier)); + if (this == obj) { + return true; + } + + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + + CacheKey other = (CacheKey) obj; + return Objects.equals(this.schema, other.schema) + && Objects.equals(this.table, other.table) + && Objects.equals(this.family, other.family) + && Objects.equals(this.qualifier, other.qualifier) + && Objects.equals(this.range, other.range); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schema", schema) + .add("table", table) + .add("family", family) + .add("qualifier", qualifier) + .add("range", range).toString(); } } @@ -261,23 +328,8 @@ private LoadingCache newCache(String schema, String table, String f * Internal class for loading the cardinality from Accumulo */ private class CardinalityCacheLoader - extends CacheLoader + extends CacheLoader { - private final String metricsTable; - private final Text columnFamily; - - public CardinalityCacheLoader( - String schema, - String table, - String family, - String qualifier) - { - this.metricsTable = Indexer.getMetricsTableName(schema, table); - - // Create the column family for our scanners - this.columnFamily = new Text(Indexer.getIndexColumnFamily(family.getBytes(UTF_8), qualifier.getBytes(UTF_8)).array()); - } - /** * Loads the cardinality for the given Range. Uses a BatchScanner and sums the cardinality for all values that encapsulate the Range. * @@ -285,64 +337,77 @@ public CardinalityCacheLoader( * @return The cardinality of the column, which would be zero if the value does not exist */ @Override - public Long load(Range key) + public Long load(@Nonnull CacheKey key) throws Exception { - // Create a BatchScanner against our metrics table, setting the value range and fetching the appropriate column - BatchScanner scanner = connector.createBatchScanner(metricsTable, auths, 10); - scanner.setRanges(ImmutableList.of(key)); - scanner.fetchColumn(columnFamily, Indexer.CARDINALITY_CQ_AS_TEXT); - - // Sum all those entries! - long sum = 0; - for (Entry entry : scanner) { - sum += Long.parseLong(entry.getValue().toString()); + LOG.debug("Loading a non-exact range from Accumulo: %s", key); + // Get metrics table name and the column family for the scanner + String metricsTable = getMetricsTableName(key.getSchema(), key.getTable()); + Text columnFamily = new Text(getIndexColumnFamily(key.getFamily().getBytes(UTF_8), key.getQualifier().getBytes(UTF_8)).array()); + + // Create scanner for querying the range + Scanner scanner = connector.createScanner(metricsTable, key.getAuths()); + scanner.setRange(key.getRange()); + scanner.fetchColumn(columnFamily, CARDINALITY_CQ_AS_TEXT); + + try { + return stream(scanner) + .map(Entry::getValue) + .map(Value::toString) + .mapToLong(Long::parseLong) + .sum(); + } + finally { + scanner.close(); } - - scanner.close(); - return sum; } - @SuppressWarnings("unchecked") @Override - public Map loadAll(Iterable keys) + public Map loadAll(@Nonnull Iterable keys) throws Exception { - LOG.debug("Loading %s exact ranges from Accumulo", ((Collection) keys).size()); - - // Create batch scanner for querying all ranges - BatchScanner scanner = connector.createBatchScanner(metricsTable, auths, 10); - scanner.setRanges((Collection) keys); - scanner.fetchColumn(columnFamily, Indexer.CARDINALITY_CQ_AS_TEXT); - - // Create a new map to hold our cardinalities for each range, returning a default of zero for each non-existent Key - Map rangeValues = new MapDefaultZero(); - for (Entry entry : scanner) { - rangeValues.put( - Range.exact(entry.getKey().getRow()), - Long.parseLong(entry.getValue().toString())); + int size = Iterables.size(keys); + if (size == 0) { + return ImmutableMap.of(); } - scanner.close(); - return rangeValues; - } + LOG.debug("Loading %s exact ranges from Accumulo", size); - /** - * We extend HashMap here and override get to return a value of zero if the key is not in the map. - * This mitigates the CacheLoader InvalidCacheLoadException if loadAll fails to return a value for a given key, - * which occurs when there is no key in Accumulo. - */ - public class MapDefaultZero - extends HashMap - { - @Override - public Long get(Object key) - { - // Get the key from our map overlord - Long value = super.get(key); - - // Return zero if null - return value == null ? 0 : value; + // In order to simplify the implementation, we are making a (safe) assumption + // that the CacheKeys will all contain the same combination of schema/table/family/qualifier + // This is asserted with the below implementation error just to make sure + CacheKey anyKey = stream(keys).findAny().get(); + if (stream(keys).anyMatch(k -> !k.getSchema().equals(anyKey.getSchema()) || !k.getTable().equals(anyKey.getTable()) || !k.getFamily().equals(anyKey.getFamily()) || !k.getQualifier().equals(anyKey.getQualifier()))) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "loadAll called with a non-homogeneous collection of cache keys"); + } + + Map rangeToKey = stream(keys).collect(Collectors.toMap(CacheKey::getRange, Function.identity())); + LOG.debug("rangeToKey size is %s", rangeToKey.size()); + + // Get metrics table name and the column family for the scanner + String metricsTable = getMetricsTableName(anyKey.getSchema(), anyKey.getTable()); + Text columnFamily = new Text(getIndexColumnFamily(anyKey.getFamily().getBytes(UTF_8), anyKey.getQualifier().getBytes(UTF_8)).array()); + + BatchScanner scanner = connector.createBatchScanner(metricsTable, anyKey.getAuths(), 10); + try { + scanner.setRanges(stream(keys).map(CacheKey::getRange).collect(Collectors.toList())); + scanner.fetchColumn(columnFamily, CARDINALITY_CQ_AS_TEXT); + + // Create a new map to hold our cardinalities for each range, returning a default of + // Zero for each non-existent Key + Map rangeValues = new HashMap<>(); + stream(keys).forEach(key -> rangeValues.put(key, 0L)); + + for (Entry entry : scanner) { + rangeValues.put(rangeToKey.get(Range.exact(entry.getKey().getRow())), parseLong(entry.getValue().toString())); + } + + return rangeValues; + } + finally { + if (scanner != null) { + scanner.close(); + } } } } diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java index e5c359f5bf166..8c0ee86ab2635 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java @@ -13,9 +13,6 @@ */ package com.facebook.presto.accumulo.index; -import com.facebook.presto.accumulo.AccumuloClient; -import com.facebook.presto.accumulo.conf.AccumuloConfig; -import com.facebook.presto.accumulo.conf.AccumuloSessionProperties; import com.facebook.presto.accumulo.model.AccumuloColumnConstraint; import com.facebook.presto.accumulo.model.TabletSplitMetadata; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; @@ -26,6 +23,7 @@ import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; import io.airlift.log.Logger; +import io.airlift.units.Duration; import org.apache.accumulo.core.client.AccumuloException; import org.apache.accumulo.core.client.AccumuloSecurityException; import org.apache.accumulo.core.client.BatchScanner; @@ -38,13 +36,30 @@ import org.apache.accumulo.core.security.Authorizations; import org.apache.hadoop.io.Text; +import javax.inject.Inject; + import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map.Entry; import java.util.Optional; import java.util.Set; - +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.accumulo.AccumuloClient.getRangesFromDomain; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getIndexCardinalityCachePollingDuration; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getIndexSmallCardThreshold; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getIndexThreshold; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.getNumIndexRowsPerSplit; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.isIndexMetricsEnabled; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.isIndexShortCircuitEnabled; +import static com.facebook.presto.accumulo.conf.AccumuloSessionProperties.isOptimizeIndexEnabled; +import static com.facebook.presto.accumulo.index.Indexer.CARDINALITY_CQ_AS_TEXT; +import static com.facebook.presto.accumulo.index.Indexer.METRICS_TABLE_ROWID_AS_TEXT; +import static com.facebook.presto.accumulo.index.Indexer.METRICS_TABLE_ROWS_CF_AS_TEXT; +import static com.facebook.presto.accumulo.index.Indexer.getIndexColumnFamily; +import static com.facebook.presto.accumulo.index.Indexer.getIndexTableName; +import static com.facebook.presto.accumulo.index.Indexer.getMetricsTableName; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.google.common.base.Preconditions.checkArgument; import static java.nio.charset.StandardCharsets.UTF_8; @@ -59,22 +74,15 @@ public class IndexLookup { private static final Logger LOG = Logger.get(IndexLookup.class); - private static final Range METRICS_TABLE_ROWID_RANGE = new Range(Indexer.METRICS_TABLE_ROWID_AS_TEXT); + private static final Range METRICS_TABLE_ROWID_RANGE = new Range(METRICS_TABLE_ROWID_AS_TEXT); private final ColumnCardinalityCache cardinalityCache; private final Connector connector; - public IndexLookup( - Connector connector, - AccumuloConfig config, - Authorizations auths) + @Inject + public IndexLookup(Connector connector, ColumnCardinalityCache cardinalityCache) { this.connector = requireNonNull(connector, "connector is null"); - this.cardinalityCache = new ColumnCardinalityCache(connector, requireNonNull(config, "config is null"), auths); - } - - public void dropCache(String schema, String table) - { - cardinalityCache.deleteCache(schema, table); + this.cardinalityCache = requireNonNull(cardinalityCache, "cardinalityCache is null"); } /** @@ -111,7 +119,7 @@ public boolean applyIndex( throws Exception { // Early out if index is disabled - if (!AccumuloSessionProperties.isOptimizeIndexEnabled(session)) { + if (!isOptimizeIndexEnabled(session)) { LOG.debug("Secondary index is disabled"); return false; } @@ -128,14 +136,14 @@ public boolean applyIndex( } // If metrics are not enabled - if (!AccumuloSessionProperties.isIndexMetricsEnabled(session)) { + if (!isIndexMetricsEnabled(session)) { LOG.debug("Use of index metrics is disabled"); // Get the ranges via the index table - List indexRanges = getIndexRanges(Indexer.getIndexTableName(schema, table), constraintRanges, rowIdRanges, auths); + List indexRanges = getIndexRanges(getIndexTableName(schema, table), constraintRanges, rowIdRanges, auths); if (!indexRanges.isEmpty()) { // Bin the ranges into TabletMetadataSplits and return true to use the tablet splits - binRanges(AccumuloSessionProperties.getNumIndexRowsPerSplit(session), indexRanges, tabletSplits); + binRanges(getNumIndexRowsPerSplit(session), indexRanges, tabletSplits); LOG.debug("Number of splits for %s.%s is %d with %d ranges", schema, table, tabletSplits.size(), indexRanges.size()); } else { @@ -157,14 +165,12 @@ private static Multimap getIndexedConstraintRan ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); for (AccumuloColumnConstraint columnConstraint : constraints) { if (columnConstraint.isIndexed()) { - for (Range range : AccumuloClient.getRangesFromDomain(columnConstraint.getDomain(), serializer)) { + for (Range range : getRangesFromDomain(columnConstraint.getDomain(), serializer)) { builder.put(columnConstraint, range); } } else { - LOG.warn( - "Query containts constraint on non-indexed column %s. Is it worth indexing?", - columnConstraint.getName()); + LOG.warn("Query containts constraint on non-indexed column %s. Is it worth indexing?", columnConstraint.getName()); } } return builder.build(); @@ -180,18 +186,33 @@ private boolean getRangesWithMetrics( Authorizations auths) throws Exception { + String metricsTable = getMetricsTableName(schema, table); + long numRows = getNumRowsInTable(metricsTable, auths); + // Get the cardinalities from the metrics table - Multimap cardinalities = cardinalityCache.getCardinalities(schema, table, constraintRanges); + Multimap cardinalities; + if (isIndexShortCircuitEnabled(session)) { + cardinalities = cardinalityCache.getCardinalities( + schema, + table, + auths, + constraintRanges, + (long) (numRows * getIndexSmallCardThreshold(session)), + getIndexCardinalityCachePollingDuration(session)); + } + else { + // disable short circuit using 0 + cardinalities = cardinalityCache.getCardinalities(schema, table, auths, constraintRanges, 0, new Duration(0, TimeUnit.MILLISECONDS)); + } + Optional> entry = cardinalities.entries().stream().findFirst(); if (!entry.isPresent()) { return false; } Entry lowestCardinality = entry.get(); - String indexTable = Indexer.getIndexTableName(schema, table); - String metricsTable = Indexer.getMetricsTableName(schema, table); - long numRows = getNumRowsInTable(metricsTable, auths); - double threshold = AccumuloSessionProperties.getIndexThreshold(session); + String indexTable = getIndexTableName(schema, table); + double threshold = getIndexThreshold(session); List indexRanges; // If the smallest cardinality in our list is above the lowest cardinality threshold, @@ -235,7 +256,7 @@ private boolean getRangesWithMetrics( // If the percentage of scanned rows, the ratio, less than the configured threshold if (ratio < threshold) { // Bin the ranges into TabletMetadataSplits and return true to use the tablet splits - binRanges(AccumuloSessionProperties.getNumIndexRowsPerSplit(session), indexRanges, tabletSplits); + binRanges(getNumIndexRowsPerSplit(session), indexRanges, tabletSplits); LOG.debug("Number of splits for %s.%s is %d with %d ranges", schema, table, tabletSplits.size(), indexRanges.size()); return true; } @@ -248,7 +269,7 @@ private boolean getRangesWithMetrics( private static boolean smallestCardAboveThreshold(ConnectorSession session, long numRows, long smallestCardinality) { double ratio = ((double) smallestCardinality / (double) numRows); - double threshold = AccumuloSessionProperties.getIndexSmallCardThreshold(session); + double threshold = getIndexSmallCardThreshold(session); LOG.debug("Smallest cardinality is %d, num rows is %d, ratio is %2f with threshold of %f", smallestCardinality, numRows, ratio, threshold); return ratio > threshold; } @@ -259,7 +280,7 @@ private long getNumRowsInTable(String metricsTable, Authorizations auths) // Create scanner against the metrics table, pulling the special column and the rows column Scanner scanner = connector.createScanner(metricsTable, auths); scanner.setRange(METRICS_TABLE_ROWID_RANGE); - scanner.fetchColumn(Indexer.METRICS_TABLE_ROWS_CF_AS_TEXT, Indexer.CARDINALITY_CQ_AS_TEXT); + scanner.fetchColumn(METRICS_TABLE_ROWS_CF_AS_TEXT, CARDINALITY_CQ_AS_TEXT); // Scan the entry and get the number of rows long numRows = -1; @@ -286,10 +307,7 @@ private List getIndexRanges(String indexTable, Multimap map) Type keyType = mapType.getTypeParameters().get(0); Type valueType = mapType.getTypeParameters().get(1); - BlockBuilder builder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), map.size() * 2); + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + BlockBuilder builder = mapBlockBuilder.beginBlockEntry(); for (Entry entry : map.entrySet()) { writeObject(builder, keyType, entry.getKey()); writeObject(builder, valueType, entry.getValue()); } - return builder.build(); + + mapBlockBuilder.closeEntry(); + return (Block) mapType.getObject(mapBlockBuilder, 0); } /** diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloClient.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloClient.java index 2d4da61560ba4..21d86ecce4253 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloClient.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloClient.java @@ -15,6 +15,8 @@ import com.facebook.presto.accumulo.conf.AccumuloConfig; import com.facebook.presto.accumulo.conf.AccumuloTableProperties; +import com.facebook.presto.accumulo.index.ColumnCardinalityCache; +import com.facebook.presto.accumulo.index.IndexLookup; import com.facebook.presto.accumulo.metadata.AccumuloTable; import com.facebook.presto.accumulo.metadata.ZooKeeperMetadataManager; import com.facebook.presto.spi.ColumnMetadata; @@ -47,7 +49,7 @@ public TestAccumuloClient() Connector connector = AccumuloQueryRunner.getAccumuloConnector(); config.setZooKeepers(connector.getInstance().getZooKeepers()); zooKeeperMetadataManager = new ZooKeeperMetadataManager(config, new TypeRegistry()); - client = new AccumuloClient(connector, config, zooKeeperMetadataManager, new AccumuloTableManager(connector)); + client = new AccumuloClient(connector, config, zooKeeperMetadataManager, new AccumuloTableManager(connector), new IndexLookup(connector, new ColumnCardinalityCache(connector, config))); } @Test diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java index 6160553aac845..1282ca71ef434 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/index/TestIndexer.java @@ -17,8 +17,8 @@ import com.facebook.presto.accumulo.model.AccumuloColumnHandle; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; import com.facebook.presto.accumulo.serializers.LexicoderRowSerializer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import org.apache.accumulo.core.client.BatchWriterConfig; import org.apache.accumulo.core.client.Connector; diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java index 6ae3492f9598c..55db1b1e98e27 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestField.java @@ -14,10 +14,16 @@ package com.facebook.presto.accumulo.model; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -172,7 +178,13 @@ public void testLong() public void testMap() throws Exception { - Type type = new MapType(VARCHAR, BIGINT); + TypeManager typeManager = new TypeRegistry(); + // associate typeManager with a function registry + new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); + + Type type = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(VARCHAR.getTypeSignature()), + TypeSignatureParameter.of(BIGINT.getTypeSignature()))); Block expected = AccumuloRowSerializer.getBlockFromMap(type, ImmutableMap.of("a", 1L, "b", 2L, "c", 3L)); Field f1 = new Field(expected, type); assertEquals(f1.getMap(), expected); diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java index 463cb01b19db6..596c81b0c4ea2 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/model/TestRow.java @@ -14,7 +14,7 @@ package com.facebook.presto.accumulo.model; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -22,7 +22,6 @@ import java.sql.Timestamp; import java.util.GregorianCalendar; import java.util.Optional; -import java.util.concurrent.TimeUnit; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -83,7 +82,7 @@ public void testRowFromString() Row expected = new Row(); expected.addField(new Field(AccumuloRowSerializer.getBlockFromArray(VARCHAR, ImmutableList.of("a", "b", "c")), new ArrayType(VARCHAR))); expected.addField(true, BOOLEAN); - expected.addField(new Field(new Date(TimeUnit.MILLISECONDS.toDays(new GregorianCalendar(1999, 0, 1).getTime().getTime())), DATE)); + expected.addField(new Field(new Date(new GregorianCalendar(1999, 0, 1).getTime().getTime()), DATE)); expected.addField(123.45678, DOUBLE); expected.addField(new Field(123.45678f, REAL)); expected.addField(12345678, INTEGER); diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java index 47d6cd153a797..dec82d6ac2194 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/serializers/AbstractTestAccumuloRowSerializer.java @@ -13,9 +13,15 @@ */ package com.facebook.presto.accumulo.serializers; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.accumulo.core.data.Key; @@ -182,8 +188,14 @@ public void testLong() public void testMap() throws Exception { + TypeManager typeManager = new TypeRegistry(); + // associate typeManager with a function registry + new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); + AccumuloRowSerializer serializer = serializerClass.getConstructor().newInstance(); - Type type = new MapType(VARCHAR, BIGINT); + Type type = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(VARCHAR.getTypeSignature()), + TypeSignatureParameter.of(BIGINT.getTypeSignature()))); Map expected = ImmutableMap.of("a", 1L, "b", 2L, "3", 3L); byte[] data = serializer.encode(type, AccumuloRowSerializer.getBlockFromMap(type, expected)); Map actual = serializer.decode(type, data); diff --git a/presto-array/pom.xml b/presto-array/pom.xml index 7bfe68580bd24..1659d5845e4b5 100644 --- a/presto-array/pom.xml +++ b/presto-array/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-array @@ -21,6 +21,11 @@ slice + + it.unimi.dsi + fastutil + + com.facebook.presto presto-spi @@ -30,5 +35,12 @@ org.openjdk.jol jol-core + + + + org.testng + testng + test + diff --git a/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java index 24d75ea1d21e7..f6bb1106315b8 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/BlockBigArray.java @@ -14,10 +14,13 @@ package com.facebook.presto.array; import com.facebook.presto.spi.block.Block; +import org.openjdk.jol.info.ClassLayout; public final class BlockBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(BlockBigArray.class).instanceSize(); private final ObjectBigArray array; + private final ReferenceCountMap trackedObjects = new ReferenceCountMap(); private long sizeOfBlocks; public BlockBigArray() @@ -35,7 +38,7 @@ public BlockBigArray(Block block) */ public long sizeOf() { - return array.sizeOf() + sizeOfBlocks; + return INSTANCE_SIZE + array.sizeOf() + sizeOfBlocks + trackedObjects.sizeOf(); } /** @@ -58,10 +61,30 @@ public void set(long index, Block value) { Block currentValue = array.get(index); if (currentValue != null) { - sizeOfBlocks -= currentValue.getRetainedSizeInBytes(); + currentValue.retainedBytesForEachPart((object, size) -> { + if (currentValue == object) { + // track instance size separately as the reference count for an instance is always 1 + sizeOfBlocks -= size; + return; + } + if (trackedObjects.decrementReference(object) == 0) { + // decrement the size only when it is the last reference + sizeOfBlocks -= size; + } + }); } if (value != null) { - sizeOfBlocks += value.getRetainedSizeInBytes(); + value.retainedBytesForEachPart((object, size) -> { + if (value == object) { + // track instance size separately as the reference count for an instance is always 1 + sizeOfBlocks += size; + return; + } + if (trackedObjects.incrementReference(object) == 1) { + // increment the size only when it is the first reference + sizeOfBlocks += size; + } + }); } array.set(index, value); } diff --git a/presto-array/src/main/java/com/facebook/presto/array/BooleanBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/BooleanBigArray.java index 0a85ac7e7be34..a46881784d22d 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/BooleanBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/BooleanBigArray.java @@ -14,6 +14,7 @@ package com.facebook.presto.array; import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -27,6 +28,7 @@ // Copyright (C) 2010-2013 Sebastiano Vigna public final class BooleanBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(BooleanBigArray.class).instanceSize(); private static final long SIZE_OF_SEGMENT = sizeOfBooleanArray(SEGMENT_SIZE); private final boolean initialValue; @@ -55,7 +57,7 @@ public BooleanBigArray(boolean initialValue) */ public long sizeOf() { - return SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); + return INSTANCE_SIZE + SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); } /** diff --git a/presto-array/src/main/java/com/facebook/presto/array/ByteBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/ByteBigArray.java index 72a241855b80b..e9745ff10f2de 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/ByteBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/ByteBigArray.java @@ -14,6 +14,7 @@ package com.facebook.presto.array; import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -27,6 +28,7 @@ // Copyright (C) 2010-2013 Sebastiano Vigna public final class ByteBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteBigArray.class).instanceSize(); private static final long SIZE_OF_SEGMENT = sizeOfByteArray(SEGMENT_SIZE); private final byte initialValue; @@ -55,7 +57,7 @@ public ByteBigArray(byte initialValue) */ public long sizeOf() { - return SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); + return INSTANCE_SIZE + SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); } /** diff --git a/presto-array/src/main/java/com/facebook/presto/array/DoubleBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/DoubleBigArray.java index f0157d5f41d6b..a4b14db86cafb 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/DoubleBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/DoubleBigArray.java @@ -14,6 +14,7 @@ package com.facebook.presto.array; import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -27,6 +28,7 @@ // Copyright (C) 2010-2013 Sebastiano Vigna public final class DoubleBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(DoubleBigArray.class).instanceSize(); private static final long SIZE_OF_SEGMENT = sizeOfDoubleArray(SEGMENT_SIZE); private final double initialValue; @@ -58,7 +60,7 @@ public DoubleBigArray(double initialValue) */ public long sizeOf() { - return SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); + return INSTANCE_SIZE + SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); } /** diff --git a/presto-array/src/main/java/com/facebook/presto/array/IntBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/IntBigArray.java index dc94b924ee956..deefed913c142 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/IntBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/IntBigArray.java @@ -14,6 +14,7 @@ package com.facebook.presto.array; import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -21,13 +22,14 @@ import static com.facebook.presto.array.BigArrays.SEGMENT_SIZE; import static com.facebook.presto.array.BigArrays.offset; import static com.facebook.presto.array.BigArrays.segment; -import static io.airlift.slice.SizeOf.sizeOfLongArray; +import static io.airlift.slice.SizeOf.sizeOfIntArray; // Note: this code was forked from fastutil (http://fastutil.di.unimi.it/) // Copyright (C) 2010-2013 Sebastiano Vigna public final class IntBigArray { - private static final long SIZE_OF_SEGMENT = sizeOfLongArray(SEGMENT_SIZE); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(IntBigArray.class).instanceSize(); + private static final long SIZE_OF_SEGMENT = sizeOfIntArray(SEGMENT_SIZE); private final int initialValue; @@ -58,7 +60,7 @@ public IntBigArray(int initialValue) */ public long sizeOf() { - return SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); + return INSTANCE_SIZE + SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); } /** diff --git a/presto-array/src/main/java/com/facebook/presto/array/LongBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/LongBigArray.java index a55c3754ae02f..b5e49beedcfa7 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/LongBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/LongBigArray.java @@ -14,6 +14,7 @@ package com.facebook.presto.array; import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -27,6 +28,7 @@ // Copyright (C) 2010-2013 Sebastiano Vigna public final class LongBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(LongBigArray.class).instanceSize(); private static final long SIZE_OF_SEGMENT = sizeOfLongArray(SEGMENT_SIZE); private final long initialValue; @@ -58,7 +60,7 @@ public LongBigArray(long initialValue) */ public long sizeOf() { - return SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); + return INSTANCE_SIZE + SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); } /** diff --git a/presto-array/src/main/java/com/facebook/presto/array/ObjectBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/ObjectBigArray.java index 156203c07f35c..c88d1a1eae890 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/ObjectBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/ObjectBigArray.java @@ -14,6 +14,7 @@ package com.facebook.presto.array; import io.airlift.slice.SizeOf; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -27,6 +28,7 @@ // Copyright (C) 2010-2013 Sebastiano Vigna public final class ObjectBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ObjectBigArray.class).instanceSize(); private static final long SIZE_OF_SEGMENT = sizeOfObjectArray(SEGMENT_SIZE); private final Object initialValue; @@ -55,7 +57,7 @@ public ObjectBigArray(Object initialValue) */ public long sizeOf() { - return SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); + return INSTANCE_SIZE + SizeOf.sizeOf(array) + (segments * SIZE_OF_SEGMENT); } /** diff --git a/presto-array/src/main/java/com/facebook/presto/array/ReferenceCountMap.java b/presto-array/src/main/java/com/facebook/presto/array/ReferenceCountMap.java new file mode 100644 index 0000000000000..b0c2130442acb --- /dev/null +++ b/presto-array/src/main/java/com/facebook/presto/array/ReferenceCountMap.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.array; + +import io.airlift.slice.SizeOf; +import it.unimi.dsi.fastutil.objects.Object2IntOpenCustomHashMap; +import org.openjdk.jol.info.ClassLayout; + +public final class ReferenceCountMap + extends Object2IntOpenCustomHashMap +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ReferenceCountMap.class).instanceSize(); + + /** + * Two different blocks can share the same underlying data + * Use the map to avoid memory over counting + */ + public ReferenceCountMap() + { + super(new ObjectStrategy()); + } + + /** + * Increments the reference count of an object by 1 and returns the updated reference count + */ + public int incrementReference(Object key) + { + return addTo(key, 1) + 1; + } + + /** + * Decrements the reference count of an object by 1 and returns the updated reference count + */ + public int decrementReference(Object key) + { + int previousCount = addTo(key, -1); + if (previousCount == 1) { + remove(key); + } + return previousCount - 1; + } + + /** + * Returns the size of this map in bytes. + */ + public long sizeOf() + { + return INSTANCE_SIZE + SizeOf.sizeOf(key) + SizeOf.sizeOf(value) + SizeOf.sizeOf(used); + } + + private static final class ObjectStrategy + implements Strategy + { + @Override + public int hashCode(Object object) + { + return System.identityHashCode(object); + } + + @Override + public boolean equals(Object left, Object right) + { + return left == right; + } + } +} diff --git a/presto-array/src/main/java/com/facebook/presto/array/SliceBigArray.java b/presto-array/src/main/java/com/facebook/presto/array/SliceBigArray.java index e757d6eb6a2df..6b620f4d67839 100644 --- a/presto-array/src/main/java/com/facebook/presto/array/SliceBigArray.java +++ b/presto-array/src/main/java/com/facebook/presto/array/SliceBigArray.java @@ -18,6 +18,7 @@ public final class SliceBigArray { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceBigArray.class).instanceSize(); private static final int SLICE_INSTANCE_SIZE = ClassLayout.parseClass(Slice.class).instanceSize(); private final ObjectBigArray array; private long sizeOfSlices; @@ -37,7 +38,7 @@ public SliceBigArray(Slice slice) */ public long sizeOf() { - return array.sizeOf() + sizeOfSlices; + return INSTANCE_SIZE + array.sizeOf() + sizeOfSlices; } /** diff --git a/presto-array/src/test/java/com/facebook/presto/array/TestBlockBigArray.java b/presto-array/src/test/java/com/facebook/presto/array/TestBlockBigArray.java new file mode 100644 index 0000000000000..60d0a0f826c52 --- /dev/null +++ b/presto-array/src/test/java/com/facebook/presto/array/TestBlockBigArray.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.array; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.block.IntArrayBlockBuilder; +import org.openjdk.jol.info.ClassLayout; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestBlockBigArray +{ + @Test + public void testRetainedSizeWithOverlappingBlocks() + { + int entries = 123; + BlockBuilder blockBuilder = new IntArrayBlockBuilder(new BlockBuilderStatus(), entries); + for (int i = 0; i < entries; i++) { + blockBuilder.writeInt(i); + } + Block block = blockBuilder.build(); + + // Verify we do not over count + int arraySize = 456; + int blocks = 7890; + BlockBigArray blockBigArray = new BlockBigArray(); + blockBigArray.ensureCapacity(arraySize); + for (int i = 0; i < blocks; i++) { + blockBigArray.set(i % arraySize, block.getRegion(0, entries)); + } + + ReferenceCountMap referenceCountMap = new ReferenceCountMap(); + referenceCountMap.incrementReference(block); + long expectedSize = ClassLayout.parseClass(BlockBigArray.class).instanceSize() + + referenceCountMap.sizeOf() + + (new ObjectBigArray()).sizeOf() + + block.getRetainedSizeInBytes() + (arraySize - 1) * ClassLayout.parseClass(block.getClass()).instanceSize(); + assertEquals(blockBigArray.sizeOf(), expectedSize); + } +} diff --git a/presto-atop/pom.xml b/presto-atop/pom.xml index b4afb5de4237c..c0384cd703970 100644 --- a/presto-atop/pom.xml +++ b/presto-atop/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-atop diff --git a/presto-base-jdbc/pom.xml b/presto-base-jdbc/pom.xml index 632d757d10a08..1e9275e84285e 100644 --- a/presto-base-jdbc/pom.xml +++ b/presto-base-jdbc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-base-jdbc diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java index 49b08005ee9e6..5725721aab7ad 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java @@ -14,6 +14,7 @@ package com.facebook.presto.plugin.jdbc; import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigSecuritySensitive; import javax.validation.constraints.NotNull; @@ -54,6 +55,7 @@ public String getConnectionPassword() } @Config("connection-password") + @ConfigSecuritySensitive public BaseJdbcConfig setConnectionPassword(String connectionPassword) { this.connectionPassword = connectionPassword; diff --git a/presto-benchmark-driver/pom.xml b/presto-benchmark-driver/pom.xml index 188bfcc22baf9..5da62b467f74e 100644 --- a/presto-benchmark-driver/pom.xml +++ b/presto-benchmark-driver/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-benchmark-driver @@ -72,6 +72,11 @@ commons-math3 + + com.squareup.okhttp3 + okhttp + + org.testng diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java index 0685b71820cb6..7e97b21ade278 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java @@ -110,7 +110,7 @@ private static URI parseServer(String server) HostAndPort host = HostAndPort.fromString(server); try { - return new URI("http", null, host.getHostText(), host.getPortOrDefault(80), null, null, null); + return new URI("http", null, host.getHost(), host.getPortOrDefault(80), null, null, null); } catch (URISyntaxException e) { throw new IllegalArgumentException(e); diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java index 2ca469875c643..1f82be9d18f22 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java @@ -15,7 +15,6 @@ import com.facebook.presto.client.ClientSession; import com.facebook.presto.client.QueryError; -import com.facebook.presto.client.QueryResults; import com.facebook.presto.client.StatementClient; import com.facebook.presto.client.StatementStats; import com.google.common.base.Throwables; @@ -28,8 +27,8 @@ import io.airlift.http.client.JsonResponseHandler; import io.airlift.http.client.Request; import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.json.JsonCodec; import io.airlift.units.Duration; +import okhttp3.OkHttpClient; import java.io.Closeable; import java.net.URI; @@ -39,6 +38,7 @@ import static com.facebook.presto.benchmark.driver.BenchmarkQueryResult.failResult; import static com.facebook.presto.benchmark.driver.BenchmarkQueryResult.passResult; +import static com.facebook.presto.client.OkHttpUtil.setupSocksProxy; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; @@ -59,8 +59,8 @@ public class BenchmarkQueryRunner private final int maxFailures; private final HttpClient httpClient; + private final OkHttpClient okHttpClient; private final List nodes; - private final JsonCodec queryResultsCodec; private int failures; @@ -77,8 +77,6 @@ public BenchmarkQueryRunner(int warm, int runs, boolean debug, int maxFailures, this.debug = debug; - this.queryResultsCodec = jsonCodec(QueryResults.class); - requireNonNull(socksProxy, "socksProxy is null"); HttpClientConfig httpClientConfig = new HttpClientConfig(); if (socksProxy.isPresent()) { @@ -87,6 +85,10 @@ public BenchmarkQueryRunner(int warm, int runs, boolean debug, int maxFailures, this.httpClient = new JettyHttpClient(httpClientConfig.setConnectTimeout(new Duration(10, TimeUnit.SECONDS))); + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + setupSocksProxy(builder, socksProxy); + this.okHttpClient = builder.build(); + nodes = getAllNodes(requireNonNull(serverUri, "serverUri is null")); } @@ -149,7 +151,7 @@ public List getSchemas(ClientSession session) failures = 0; while (true) { // start query - StatementClient client = new StatementClient(httpClient, queryResultsCodec, session, "show schemas"); + StatementClient client = new StatementClient(okHttpClient, session, "show schemas"); // read query output ImmutableList.Builder schemas = ImmutableList.builder(); @@ -190,7 +192,7 @@ public List getSchemas(ClientSession session) private StatementStats execute(ClientSession session, String name, String query) { // start query - StatementClient client = new StatementClient(httpClient, queryResultsCodec, session, query); + StatementClient client = new StatementClient(okHttpClient, session, query); // read query output while (client.isValid() && client.advance()) { diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java index 8d15446a3e19b..282236e91ab7b 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/Suite.java @@ -26,11 +26,11 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.regex.Pattern; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; import static io.airlift.json.JsonCodec.mapJsonCodec; import static java.util.Objects.requireNonNull; @@ -93,11 +93,11 @@ public List selectQueries(Iterable queries) return ImmutableList.copyOf(queries); } - List filteredQueries = StreamSupport.stream(queries.spliterator(), false) + List filteredQueries = stream(queries) .filter(query -> getQueryNamePatterns().stream().anyMatch(pattern -> pattern.matcher(query.getName()).matches())) - .collect(Collectors.toList()); + .collect(toImmutableList()); - return ImmutableList.copyOf(filteredQueries); + return filteredQueries; } @Override diff --git a/presto-benchmark/pom.xml b/presto-benchmark/pom.xml index 24d0a214a70a4..906fb82e54549 100644 --- a/presto-benchmark/pom.xml +++ b/presto-benchmark/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-benchmark diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java index f129fab0ac087..f931743a55bd8 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java @@ -123,10 +123,13 @@ public void runBenchmark(@Nullable BenchmarkResultHook benchmarkResultHook) long outputRows = resultsAvg.get("output_rows").longValue(); DataSize outputBytes = new DataSize(resultsAvg.get("output_bytes"), BYTE); - System.out.printf("%35s :: %8.3f cpu ms :: in %5s, %6s, %8s, %8s :: out %5s, %6s, %8s, %8s%n", + DataSize memory = new DataSize(resultsAvg.get("peak_memory"), BYTE); + System.out.printf("%35s :: %8.3f cpu ms :: %5s peak memory :: in %5s, %6s, %8s, %8s :: out %5s, %6s, %8s, %8s%n", getBenchmarkName(), cpuNanos.getValue(MILLISECONDS), + formatDataSize(memory, true), + formatCount(inputRows), formatDataSize(inputBytes, true), formatCountRate(inputRows, cpuNanos, true), diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java index 9dd25dfb59b46..4d3fa50b2c6fa 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java @@ -101,21 +101,28 @@ protected OperatorFactory createHashProjectOperator(int operatorId, PlanNodeId p protected abstract List createDrivers(TaskContext taskContext); - protected void execute(TaskContext taskContext) + protected Map execute(TaskContext taskContext) { List drivers = createDrivers(taskContext); + long peakMemory = 0; boolean done = false; while (!done) { boolean processed = false; for (Driver driver : drivers) { if (!driver.isFinished()) { driver.process(); + long lastPeakMemory = peakMemory; + peakMemory = (long) taskContext.getTaskStats().getMemoryReservation().getValue(BYTE); + if (peakMemory <= lastPeakMemory) { + peakMemory = lastPeakMemory; + } processed = true; } } done = !processed; } + return ImmutableMap.of("peak_memory", peakMemory); } @Override @@ -136,7 +143,7 @@ protected Map runOnce() false); CpuTimer cpuTimer = new CpuTimer(); - execute(taskContext); + Map executionStats = execute(taskContext); CpuDuration executionTime = cpuTimer.elapsedTime(); TaskStats taskStats = taskContext.getTaskStats(); @@ -149,6 +156,7 @@ protected Map runOnce() return ImmutableMap.builder() // legacy computed values + .putAll(executionStats) .put("elapsed_millis", executionTime.getWall().toMillis()) .put("input_rows_per_second", (long) (inputRows / executionTime.getWall().getValue(SECONDS))) .put("output_rows_per_second", (long) (outputRows / executionTime.getWall().getValue(SECONDS))) diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/ArrayAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/ArrayAggregationBenchmark.java new file mode 100644 index 0000000000000..13a2ad76e21ba --- /dev/null +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/ArrayAggregationBenchmark.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.benchmark; + +import com.facebook.presto.testing.LocalQueryRunner; + +import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; + +public class ArrayAggregationBenchmark + extends AbstractSqlBenchmark +{ + public ArrayAggregationBenchmark(LocalQueryRunner localQueryRunner) + { + super(localQueryRunner, "sql_double_array_agg", 10, 100, "select array_agg(totalprice) from orders group by orderkey"); + } + + public static void main(String[] args) + { + new ArrayAggregationBenchmark(createLocalQueryRunner()).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out)); + } +} diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java index e6136949498f8..4f46cbdba4293 100644 --- a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java @@ -59,7 +59,7 @@ public static class Context private MemoryLocalQueryRunner queryRunner; @Param({"true", "false"}) - private String fastInequalityJoin; + private String fastInequalityJoins; // number of buckets. The smaller number of buckets, the longer position links chain @Param({"100", "1000", "10000", "60000"}) @@ -78,7 +78,7 @@ public MemoryLocalQueryRunner getQueryRunner() @Setup public void setUp() { - queryRunner = new MemoryLocalQueryRunner(ImmutableMap.of(SystemSessionProperties.FAST_INEQUALITY_JOIN, fastInequalityJoin)); + queryRunner = new MemoryLocalQueryRunner(ImmutableMap.of(SystemSessionProperties.FAST_INEQUALITY_JOINS, fastInequalityJoins)); // t1.val1 is in range [0, 1000) // t1.bucket is in [0, 1000) diff --git a/presto-benchto-benchmarks/pom.xml b/presto-benchto-benchmarks/pom.xml index 2af15a723f623..00a486eb29c58 100644 --- a/presto-benchto-benchmarks/pom.xml +++ b/presto-benchto-benchmarks/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-benchto-benchmarks diff --git a/presto-blackhole/pom.xml b/presto-blackhole/pom.xml index aead398ac4bbd..a931ab46f4f05 100644 --- a/presto-blackhole/pom.xml +++ b/presto-blackhole/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-blackhole diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSourceProvider.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSourceProvider.java index 86bec9bbeb40c..abf1e2cdd75ae 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSourceProvider.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSourceProvider.java @@ -145,7 +145,7 @@ else if (javaType == Slice.class) { private boolean isSupportedType(Type type) { - return ImmutableSet.of(TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, DATE, TIMESTAMP, VARBINARY).contains(type) + return ImmutableSet.of(TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, DATE, TIMESTAMP, VARBINARY).contains(type) || isVarcharType(type) || isLongDecimal(type) || isShortDecimal(type); } } diff --git a/presto-bytecode/pom.xml b/presto-bytecode/pom.xml index e0c248f5b0658..5630686676dac 100644 --- a/presto-bytecode/pom.xml +++ b/presto-bytecode/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-bytecode diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java index 3ea1903c52ff4..ad94f32a87ce4 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Set; +import static com.facebook.presto.bytecode.Access.INTERFACE; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.Access.toAccessModifier; @@ -64,7 +65,7 @@ public ClassDefinition( ParameterizedType... interfaces) { requireNonNull(access, "access is null"); - requireNonNull(access, "access is null"); + requireNonNull(type, "type is null"); requireNonNull(superClass, "superClass is null"); requireNonNull(interfaces, "interfaces is null"); @@ -121,6 +122,11 @@ public List getMethods() return ImmutableList.copyOf(methods); } + public boolean isInterface() + { + return access.contains(INTERFACE); + } + public void visit(ClassVisitor visitor) { // Generic signature if super class or any interface is generic @@ -133,7 +139,8 @@ public void visit(ClassVisitor visitor) for (int i = 0; i < interfaces.length; i++) { interfaces[i] = this.interfaces.get(i).getClassName(); } - visitor.visit(V1_7, toAccessModifier(access) | ACC_SUPER, type.getClassName(), signature, superClass.getClassName(), interfaces); + int accessModifier = toAccessModifier(access); + visitor.visit(V1_7, isInterface() ? accessModifier : accessModifier | ACC_SUPER, type.getClassName(), signature, superClass.getClassName(), interfaces); // visit source if (source != null) { @@ -151,7 +158,9 @@ public void visit(ClassVisitor visitor) } // visit clinit method - classInitializer.visit(visitor, true); + if (!isInterface()) { + classInitializer.visit(visitor, true); + } // visit methods for (MethodDefinition method : methods) { @@ -210,6 +219,9 @@ public ClassDefinition addField(FieldDefinition field) public MethodDefinition getClassInitializer() { + if (isInterface()) { + throw new IllegalAccessError("Interface does not have class initializer"); + } return classInitializer; } diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/DumpBytecodeVisitor.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/DumpBytecodeVisitor.java index 0062e073fa24b..9e0655275b382 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/DumpBytecodeVisitor.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/DumpBytecodeVisitor.java @@ -50,6 +50,7 @@ import java.util.Collection; import java.util.List; +import static com.facebook.presto.bytecode.Access.INTERFACE; import static com.facebook.presto.bytecode.ParameterizedType.type; public class DumpBytecodeVisitor @@ -72,7 +73,11 @@ public Void visitClass(ClassDefinition classDefinition) } // print class declaration - Line classDeclaration = line().addAll(classDefinition.getAccess()).add("class").add(classDefinition.getType().getJavaClassName()); + Line classDeclaration = line().addAll(classDefinition.getAccess()); + if (!classDefinition.getAccess().contains(INTERFACE)) { + classDeclaration.add("class"); + } + classDeclaration.add(classDefinition.getType().getJavaClassName()); if (!classDefinition.getSuperClass().equals(type(Object.class))) { classDeclaration.add("extends").add(classDefinition.getSuperClass().getJavaClassName()); } @@ -98,6 +103,9 @@ public Void visitClass(ClassDefinition classDefinition) visitMethod(classDefinition, methodDefinition); } + // print class initializer + visitMethod(classDefinition, classDefinition.getClassInitializer()); + indentLevel--; printLine("}"); printLine(); diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java index 5b651937091c5..0d0c50b015af8 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java @@ -84,9 +84,8 @@ public MethodDefinition( this.parameters = ImmutableList.copyOf(parameters); this.parameterTypes = Lists.transform(this.parameters, Parameter::getType); this.parameterAnnotations = ImmutableList.copyOf(transform(parameters, input -> new ArrayList<>())); - Optional thisType = Optional.empty(); - if (!access.contains(STATIC)) { + if (!declaringClass.isInterface() && !access.contains(STATIC)) { thisType = Optional.of(declaringClass.getType()); } scope = new Scope(thisType, parameters); @@ -171,6 +170,9 @@ public String getMethodDescriptor() public BytecodeBlock getBody() { + if (declaringClass.isInterface()) { + throw new IllegalAccessError("Interface does not have method body"); + } return body; } @@ -236,19 +238,19 @@ public void visit(ClassVisitor visitor, boolean addReturn) parameterAnnotation.visitParameterAnnotation(parameterIndex, methodVisitor); } } - - // visit code - methodVisitor.visitCode(); - - // visit instructions - MethodGenerationContext generationContext = new MethodGenerationContext(methodVisitor); - generationContext.enterScope(scope); - body.accept(methodVisitor, generationContext); - if (addReturn) { - new InsnNode(RETURN).accept(methodVisitor); + if (!declaringClass.isInterface()) { + // visit code + methodVisitor.visitCode(); + + // visit instructions + MethodGenerationContext generationContext = new MethodGenerationContext(methodVisitor); + generationContext.enterScope(scope); + body.accept(methodVisitor, generationContext); + if (addReturn) { + new InsnNode(RETURN).accept(methodVisitor); + } + generationContext.exitScope(scope); } - generationContext.exitScope(scope); - // done methodVisitor.visitMaxs(-1, -1); methodVisitor.visitEnd(); diff --git a/presto-cassandra/pom.xml b/presto-cassandra/pom.xml index bd8302c718cab..705b52946058f 100644 --- a/presto-cassandra/pom.xml +++ b/presto-cassandra/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-cassandra diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CachingCassandraSchemaProvider.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CachingCassandraSchemaProvider.java deleted file mode 100644 index 92a637e8bb794..0000000000000 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CachingCassandraSchemaProvider.java +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cassandra; - -import com.facebook.presto.spi.NotFoundException; -import com.facebook.presto.spi.SchemaNotFoundException; -import com.facebook.presto.spi.SchemaTableName; -import com.facebook.presto.spi.TableNotFoundException; -import com.google.common.base.Throwables; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Maps; -import com.google.common.util.concurrent.UncheckedExecutionException; -import io.airlift.units.Duration; -import org.weakref.jmx.Managed; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; - -import static com.facebook.presto.cassandra.RetryDriver.retry; -import static com.google.common.cache.CacheLoader.asyncReloading; -import static java.util.Locale.ENGLISH; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; - -/** - * Cassandra Schema Cache - */ -@ThreadSafe -public class CachingCassandraSchemaProvider -{ - private final String connectorId; - private final CassandraSession session; - - /** - * Mapping from an empty string to all schema names. Each schema name is a - * mapping from the lower case schema name to the case sensitive schema name. - * This mapping is necessary because Presto currently does not properly handle - * case sensitive names. - */ - private final LoadingCache> schemaNamesCache; - - /** - * Mapping from lower case schema name to all tables in that schema. Each - * table name is a mapping from the lower case table name to the case - * sensitive table name. This mapping is necessary because Presto currently - * does not properly handle case sensitive names. - */ - private final LoadingCache> tableNamesCache; - private final LoadingCache tableCache; - - @Inject - public CachingCassandraSchemaProvider( - CassandraConnectorId connectorId, - CassandraSession session, - @ForCassandra ExecutorService executor, - CassandraClientConfig cassandraClientConfig) - { - this(requireNonNull(connectorId, "connectorId is null").toString(), - session, - executor, - requireNonNull(cassandraClientConfig, "cassandraClientConfig is null").getSchemaCacheTtl(), - cassandraClientConfig.getSchemaRefreshInterval()); - } - - public CachingCassandraSchemaProvider(String connectorId, CassandraSession session, ExecutorService executor, Duration cacheTtl, Duration refreshInterval) - { - this.connectorId = requireNonNull(connectorId, "connectorId is null"); - this.session = requireNonNull(session, "cassandraSession is null"); - - requireNonNull(executor, "executor is null"); - - long expiresAfterWriteMillis = requireNonNull(cacheTtl, "cacheTtl is null").toMillis(); - long refreshMills = requireNonNull(refreshInterval, "refreshInterval is null").toMillis(); - - schemaNamesCache = CacheBuilder.newBuilder() - .expireAfterWrite(expiresAfterWriteMillis, MILLISECONDS) - .refreshAfterWrite(refreshMills, MILLISECONDS) - .build(asyncReloading(new CacheLoader>() - { - @Override - public Map load(String key) - throws Exception - { - return loadAllSchemas(); - } - }, executor)); - - tableNamesCache = CacheBuilder.newBuilder() - .expireAfterWrite(expiresAfterWriteMillis, MILLISECONDS) - .refreshAfterWrite(refreshMills, MILLISECONDS) - .build(asyncReloading(new CacheLoader>() - { - @Override - public Map load(String databaseName) - throws Exception - { - return loadAllTables(databaseName); - } - }, executor)); - - tableCache = CacheBuilder.newBuilder() - .expireAfterWrite(expiresAfterWriteMillis, MILLISECONDS) - .refreshAfterWrite(refreshMills, MILLISECONDS) - .build(asyncReloading(new CacheLoader() - { - @Override - public CassandraTable load(SchemaTableName tableName) - throws Exception - { - return loadTable(tableName); - } - }, executor)); - } - - @Managed - public void flushCache() - { - schemaNamesCache.invalidateAll(); - tableNamesCache.invalidateAll(); - tableCache.invalidateAll(); - } - - public List getAllSchemas() - { - return ImmutableList.copyOf(getCacheValue(schemaNamesCache, "", RuntimeException.class).keySet()); - } - - private Map loadAllSchemas() - throws Exception - { - return retry() - .stopOnIllegalExceptions() - .run("getAllSchemas", () -> Maps.uniqueIndex(session.getAllSchemas(), CachingCassandraSchemaProvider::toLowerCase)); - } - - public List getAllTables(String databaseName) - throws SchemaNotFoundException - { - return ImmutableList.copyOf(getCacheValue(tableNamesCache, databaseName, SchemaNotFoundException.class).keySet()); - } - - private Map loadAllTables(final String databaseName) - throws Exception - { - return retry().stopOn(NotFoundException.class).stopOnIllegalExceptions() - .run("getAllTables", () -> { - String caseSensitiveDatabaseName = getCaseSensitiveSchemaName(databaseName); - if (caseSensitiveDatabaseName == null) { - caseSensitiveDatabaseName = databaseName; - } - List tables = session.getAllTables(caseSensitiveDatabaseName); - Map nameMap = Maps.uniqueIndex(tables, CachingCassandraSchemaProvider::toLowerCase); - - if (tables.isEmpty()) { - // Check to see if the database exists - session.getSchema(databaseName); - } - return nameMap; - }); - } - - public CassandraTableHandle getTableHandle(SchemaTableName schemaTableName) - { - requireNonNull(schemaTableName, "schemaTableName is null"); - String schemaName = getCaseSensitiveSchemaName(schemaTableName.getSchemaName()); - String tableName = getCaseSensitiveTableName(schemaTableName); - CassandraTableHandle tableHandle = new CassandraTableHandle(connectorId, schemaName, tableName); - return tableHandle; - } - - public String getCaseSensitiveSchemaName(String caseInsensitiveName) - { - String caseSensitiveSchemaName = getCacheValue(schemaNamesCache, "", RuntimeException.class).get(caseInsensitiveName.toLowerCase(ENGLISH)); - return caseSensitiveSchemaName == null ? caseInsensitiveName : caseSensitiveSchemaName; - } - - public String getCaseSensitiveTableName(SchemaTableName schemaTableName) - { - String caseSensitiveTableName = getCacheValue(tableNamesCache, schemaTableName.getSchemaName(), SchemaNotFoundException.class).get(schemaTableName.getTableName().toLowerCase(ENGLISH)); - return caseSensitiveTableName == null ? schemaTableName.getTableName() : caseSensitiveTableName; - } - - public CassandraTable getTable(CassandraTableHandle tableHandle) - throws TableNotFoundException - { - return getCacheValue(tableCache, tableHandle.getSchemaTableName(), TableNotFoundException.class); - } - - public void flushTable(SchemaTableName tableName) - { - tableCache.invalidate(tableName); - tableNamesCache.invalidate(tableName.getSchemaName()); - schemaNamesCache.invalidateAll(); - } - - private CassandraTable loadTable(final SchemaTableName tableName) - throws Exception - { - return retry() - .stopOn(NotFoundException.class) - .stopOnIllegalExceptions() - .run("getTable", () -> session.getTable(tableName)); - } - - private static V getCacheValue(LoadingCache cache, K key, Class exceptionClass) - throws E - { - try { - return cache.get(key); - } - catch (ExecutionException | UncheckedExecutionException e) { - Throwable t = e.getCause(); - Throwables.propagateIfInstanceOf(t, exceptionClass); - throw Throwables.propagate(t); - } - } - - private static String toLowerCase(String value) - { - return value.toLowerCase(ENGLISH); - } -} diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java index 40c5fd1f58b00..73a8a321df107 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableList; import io.airlift.configuration.Config; import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.ConfigSecuritySensitive; import io.airlift.configuration.DefunctConfig; import io.airlift.units.Duration; import io.airlift.units.MaxDuration; @@ -30,20 +31,17 @@ import java.util.Arrays; import java.util.List; -import java.util.concurrent.TimeUnit; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; @DefunctConfig({"cassandra.thrift-port", "cassandra.partitioner", "cassandra.thrift-connection-factory-class", "cassandra.transport-factory-options", - "cassandra.no-host-available-retry-count"}) + "cassandra.no-host-available-retry-count", "cassandra.max-schema-refresh-threads", "cassandra.schema-cache-ttl", + "cassandra.schema-refresh-interval"}) public class CassandraClientConfig { private static final Splitter SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings(); - private Duration schemaCacheTtl = new Duration(1, TimeUnit.HOURS); - private Duration schemaRefreshInterval = new Duration(2, TimeUnit.MINUTES); - private int maxSchemaRefreshThreads = 1; private ConsistencyLevel consistencyLevel = ConsistencyLevel.ONE; private int fetchSize = 5_000; private List contactPoints = ImmutableList.of(); @@ -69,45 +67,6 @@ public class CassandraClientConfig private int speculativeExecutionLimit = 1; private Duration speculativeExecutionDelay = new Duration(500, MILLISECONDS); - @Min(1) - public int getMaxSchemaRefreshThreads() - { - return maxSchemaRefreshThreads; - } - - @Config("cassandra.max-schema-refresh-threads") - public CassandraClientConfig setMaxSchemaRefreshThreads(int maxSchemaRefreshThreads) - { - this.maxSchemaRefreshThreads = maxSchemaRefreshThreads; - return this; - } - - @NotNull - public Duration getSchemaCacheTtl() - { - return schemaCacheTtl; - } - - @Config("cassandra.schema-cache-ttl") - public CassandraClientConfig setSchemaCacheTtl(Duration schemaCacheTtl) - { - this.schemaCacheTtl = schemaCacheTtl; - return this; - } - - @NotNull - public Duration getSchemaRefreshInterval() - { - return schemaRefreshInterval; - } - - @Config("cassandra.schema-refresh-interval") - public CassandraClientConfig setSchemaRefreshInterval(Duration schemaRefreshInterval) - { - this.schemaRefreshInterval = schemaRefreshInterval; - return this; - } - @NotNull @Size(min = 1) public List getContactPoints() @@ -224,6 +183,7 @@ public String getPassword() } @Config("cassandra.password") + @ConfigSecuritySensitive public CassandraClientConfig setPassword(String password) { this.password = password; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java index d1cb67f7eee4c..141ab37d9d30d 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java @@ -35,17 +35,12 @@ import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ExecutorService; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.configuration.ConfigBinder.configBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.Executors.newFixedThreadPool; -import static org.weakref.jmx.ObjectNames.generatedNameOf; -import static org.weakref.jmx.guice.ExportBinder.newExporter; public class CassandraClientModule implements Module @@ -71,22 +66,9 @@ public void configure(Binder binder) configBinder(binder).bindConfig(CassandraClientConfig.class); - binder.bind(CachingCassandraSchemaProvider.class).in(Scopes.SINGLETON); - newExporter(binder).export(CachingCassandraSchemaProvider.class).as(generatedNameOf(CachingCassandraSchemaProvider.class, connectorId)); - jsonCodecBinder(binder).bindListJsonCodec(ExtraColumnMetadata.class); } - @ForCassandra - @Singleton - @Provides - public static ExecutorService createCachingCassandraSchemaExecutor(CassandraConnectorId clientId, CassandraClientConfig cassandraClientConfig) - { - return newFixedThreadPool( - cassandraClientConfig.getMaxSchemaRefreshThreads(), - daemonThreadsNamed("cassandra-" + clientId + "-%s")); - } - @Singleton @Provides public static CassandraSession createCassandraSession( @@ -161,6 +143,10 @@ public static CassandraSession createCassandraSession( )); } - return new NativeCassandraSession(connectorId.toString(), extraColumnMetadataCodec, clusterBuilder.build(), config.getNoHostAvailableRetryTimeout()); + return new NativeCassandraSession( + connectorId.toString(), + extraColumnMetadataCodec, + new ReopeningCluster(clusterBuilder::build), + config.getNoHostAvailableRetryTimeout()); } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java index eecbbf6ae2f05..6693baff17b0f 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClusteringPredicatesExtractor.java @@ -13,66 +13,41 @@ */ package com.facebook.presto.cassandra; +import com.datastax.driver.core.VersionNumber; import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; +import com.google.common.base.Joiner; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import static com.facebook.presto.cassandra.util.CassandraCqlUtils.toCQLCompatibleString; -import static com.google.common.collect.Sets.cartesianProduct; import static java.util.Objects.requireNonNull; public class CassandraClusteringPredicatesExtractor { private final List clusteringColumns; - private final TupleDomain predicates; private final ClusteringPushDownResult clusteringPushDownResult; + private final TupleDomain predicates; - public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates) + public CassandraClusteringPredicatesExtractor(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) { - this.clusteringColumns = ImmutableList.copyOf(requireNonNull(clusteringColumns, "clusteringColumns is null")); + this.clusteringColumns = ImmutableList.copyOf(clusteringColumns); this.predicates = requireNonNull(predicates, "predicates is null"); - this.clusteringPushDownResult = getClusteringKeysSet(clusteringColumns, predicates); + this.clusteringPushDownResult = getClusteringKeysSet(clusteringColumns, predicates, requireNonNull(cassandraVersion, "cassandraVersion is null")); } - public List getClusteringKeyPredicates() + public String getClusteringKeyPredicates() { - Set> pushedDownDomainValues = clusteringPushDownResult.getDomainValues(); - - if (pushedDownDomainValues.isEmpty()) { - return ImmutableList.of(); - } - - ImmutableList.Builder clusteringPredicates = ImmutableList.builder(); - for (List clusteringKeys : pushedDownDomainValues) { - if (clusteringKeys.isEmpty()) { - continue; - } - - StringBuilder stringBuilder = new StringBuilder(); - - for (int i = 0; i < clusteringKeys.size(); i++) { - if (i > 0) { - stringBuilder.append(" AND "); - } - - stringBuilder.append(CassandraCqlUtils.validColumnName(clusteringColumns.get(i).getName())); - stringBuilder.append(" = "); - stringBuilder.append(CassandraCqlUtils.cqlValue(toCQLCompatibleString(clusteringKeys.get(i)), clusteringColumns.get(i).getCassandraType())); - } - - clusteringPredicates.add(stringBuilder.toString()); - } - return clusteringPredicates.build(); + return clusteringPushDownResult.getDomainQuery(); } public TupleDomain getUnenforcedConstraints() @@ -87,65 +62,133 @@ public TupleDomain getUnenforcedConstraints() return TupleDomain.withColumnDomains(notPushedDown); } - private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates) + private static ClusteringPushDownResult getClusteringKeysSet(List clusteringColumns, TupleDomain predicates, VersionNumber cassandraVersion) { ImmutableMap.Builder domainsBuilder = ImmutableMap.builder(); - ImmutableList.Builder> clusteringColumnValues = ImmutableList.builder(); + ImmutableList.Builder clusteringColumnSql = ImmutableList.builder(); + int currentClusteringColumn = 0; for (CassandraColumnHandle columnHandle : clusteringColumns) { Domain domain = predicates.getDomains().get().get(columnHandle); - if (domain == null) { break; } - if (domain.isNullAllowed()) { - return new ClusteringPushDownResult(domainsBuilder.build(), ImmutableSet.of()); + break; } - - Set values = domain.getValues().getValuesProcessor().transform( + String predicateString = null; + predicateString = domain.getValues().getValuesProcessor().transform( ranges -> { - ImmutableSet.Builder columnValues = ImmutableSet.builder(); - for (Range range : ranges.getOrderedRanges()) { - if (!range.isSingleValue()) { - return ImmutableSet.of(); + List singleValues = new ArrayList<>(); + List rangeConjuncts = new ArrayList<>(); + String predicate = null; + + for (Range range : ranges.getOrderedRanges()) { + if (range.isAll()) { + return null; + } + if (range.isSingleValue()) { + singleValues.add(CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getSingleValue()), + columnHandle.getCassandraType())); + } + else { + if (!range.getLow().isLowerUnbounded()) { + switch (range.getLow().getBound()) { + case ABOVE: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " > " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getLow().getValue()), + columnHandle.getCassandraType())); + break; + case EXACTLY: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " >= " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getLow().getValue()), + columnHandle.getCassandraType())); + break; + case BELOW: + throw new VerifyException("Low Marker should never use BELOW bound"); + default: + throw new AssertionError("Unhandled bound: " + range.getLow().getBound()); } - /* TODO add code to handle a range of values for the last column - * Prior to Cassandra 2.2, only the last clustering column can have a range of values - * Take a look at how this is done in PreparedStatementBuilder.java - */ - - Object value = range.getSingleValue(); - - CassandraType valueType = columnHandle.getCassandraType(); - columnValues.add(valueType.validateClusteringKey(value)); } - return columnValues.build(); - }, - discreteValues -> { - if (discreteValues.isWhiteList()) { - return ImmutableSet.copyOf(discreteValues.getValues()); + if (!range.getHigh().isUpperUnbounded()) { + switch (range.getHigh().getBound()) { + case ABOVE: + throw new VerifyException("High Marker should never use ABOVE bound"); + case EXACTLY: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " <= " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getHigh().getValue()), + columnHandle.getCassandraType())); + break; + case BELOW: + rangeConjuncts.add(CassandraCqlUtils.validColumnName(columnHandle.getName()) + " < " + + CassandraCqlUtils.cqlValue(toCQLCompatibleString(range.getHigh().getValue()), + columnHandle.getCassandraType())); + break; + default: + throw new AssertionError("Unhandled bound: " + range.getHigh().getBound()); + } } - return ImmutableSet.of(); - }, - allOrNone -> ImmutableSet.of()); + } + } + + if (!singleValues.isEmpty() && !rangeConjuncts.isEmpty()) { + return null; + } + if (!singleValues.isEmpty()) { + if (singleValues.size() == 1) { + predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " = " + singleValues.get(0); + } + else { + predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + + Joiner.on(",").join(singleValues) + ")"; + } + } + else if (!rangeConjuncts.isEmpty()) { + predicate = Joiner.on(" AND ").join(rangeConjuncts); + } + return predicate; + }, discreteValues -> { + if (discreteValues.isWhiteList()) { + ImmutableList.Builder discreteValuesList = ImmutableList.builder(); + for (Object discreteValue : discreteValues.getValues()) { + discreteValuesList.add(CassandraCqlUtils.cqlValue(toCQLCompatibleString(discreteValue), + columnHandle.getCassandraType())); + } + String predicate = CassandraCqlUtils.validColumnName(columnHandle.getName()) + " IN (" + + Joiner.on(",").join(discreteValuesList.build()) + ")"; + return predicate; + } + return null; + }, allOrNone -> null); - if (!values.isEmpty()) { - clusteringColumnValues.add(values); - domainsBuilder.put(columnHandle, domain); + if (predicateString == null) { + break; + } + // IN restriction only on last clustering column for Cassandra version = 2.1 + if (predicateString.contains(" IN (") && cassandraVersion.compareTo(VersionNumber.parse("2.2.0")) < 0 && currentClusteringColumn != (clusteringColumns.size() - 1)) { + break; + } + clusteringColumnSql.add(predicateString); + domainsBuilder.put(columnHandle, domain); + // Check for last clustering column should only be restricted by range condition + if (predicateString.contains(">") || predicateString.contains("<")) { + break; } + currentClusteringColumn++; } - return new ClusteringPushDownResult(domainsBuilder.build(), cartesianProduct(clusteringColumnValues.build())); + List clusteringColumnPredicates = clusteringColumnSql.build(); + + return new ClusteringPushDownResult(domainsBuilder.build(), Joiner.on(" AND ").join(clusteringColumnPredicates)); } private static class ClusteringPushDownResult { private final Map domains; - private final Set> domainValues; + private final String domainQuery; - public ClusteringPushDownResult(Map domains, Set> domainValues) + public ClusteringPushDownResult(Map domains, String domainQuery) { this.domains = requireNonNull(ImmutableMap.copyOf(domains)); - this.domainValues = requireNonNull(ImmutableSet.copyOf(domainValues)); + this.domainQuery = requireNonNull(domainQuery); } public Map getDomains() @@ -153,9 +196,9 @@ public Map getDomains() return domains; } - public Set> getDomainValues() + public String getDomainQuery() { - return domainValues; + return domainQuery; } } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java index 45eeded00c66a..fcd4dcd5fc86a 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnectorRecordSinkProvider.java @@ -43,12 +43,28 @@ public RecordSink getRecordSink(ConnectorTransactionHandle transaction, Connecto checkArgument(tableHandle instanceof CassandraOutputTableHandle, "tableHandle is not an instance of CassandraOutputTableHandle"); CassandraOutputTableHandle handle = (CassandraOutputTableHandle) tableHandle; - return new CassandraRecordSink(handle, cassandraSession); + return new CassandraRecordSink( + cassandraSession, + handle.getSchemaName(), + handle.getTableName(), + handle.getColumnNames(), + handle.getColumnTypes(), + true); } @Override public RecordSink getRecordSink(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorInsertTableHandle tableHandle) { - throw new UnsupportedOperationException(); + requireNonNull(tableHandle, "tableHandle is null"); + checkArgument(tableHandle instanceof CassandraInsertTableHandle, "tableHandle is not an instance of ConnectorInsertTableHandle"); + CassandraInsertTableHandle handle = (CassandraInsertTableHandle) tableHandle; + + return new CassandraRecordSink( + cassandraSession, + handle.getSchemaName(), + handle.getTableName(), + handle.getColumnNames(), + handle.getColumnTypes(), + false); } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java index 362739abca764..70d415ebebba8 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraErrorCode.java @@ -22,7 +22,7 @@ public enum CassandraErrorCode implements ErrorCodeSupplier { - CASSANDRA_METADATA_ERROR(0, EXTERNAL); + CASSANDRA_METADATA_ERROR(0, EXTERNAL), CASSANDRA_VERSION_ERROR(1, EXTERNAL); private final ErrorCode errorCode; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java index f2fe188af7192..c82269ea0f22e 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraHandleResolver.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -59,4 +60,10 @@ public Class getTransactionHandleClass() { return CassandraTransactionHandle.class; } + + @Override + public Class getInsertTableHandleClass() + { + return CassandraInsertTableHandle.class; + } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraInsertTableHandle.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraInsertTableHandle.java new file mode 100644 index 0000000000000..57d1e55abeb04 --- /dev/null +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraInsertTableHandle.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cassandra; + +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class CassandraInsertTableHandle + implements ConnectorInsertTableHandle +{ + private final String connectorId; + private final String schemaName; + private final String tableName; + private final List columnNames; + private final List columnTypes; + + @JsonCreator + public CassandraInsertTableHandle( + @JsonProperty("connectorId") String connectorId, + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("columnNames") List columnNames, + @JsonProperty("columnTypes") List columnTypes) + { + this.connectorId = requireNonNull(connectorId, "clientId is null"); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + + requireNonNull(columnNames, "columnNames is null"); + requireNonNull(columnTypes, "columnTypes is null"); + checkArgument(columnNames.size() == columnTypes.size(), "columnNames and columnTypes sizes don't match"); + this.columnNames = ImmutableList.copyOf(columnNames); + this.columnTypes = ImmutableList.copyOf(columnTypes); + } + + @JsonProperty + public String getConnectorId() + { + return connectorId; + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public List getColumnNames() + { + return columnNames; + } + + @JsonProperty + public List getColumnTypes() + { + return columnTypes; + } + + @Override + public String toString() + { + return "cassandra:" + schemaName + "." + tableName; + } +} diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java index b3f465c84284e..3307f00f75bcd 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java @@ -16,6 +16,7 @@ import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; @@ -30,6 +31,7 @@ import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.predicate.TupleDomain; @@ -46,12 +48,16 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.cassandra.CassandraType.toCassandraType; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validSchemaName; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validTableName; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.PERMISSION_DENIED; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -60,7 +66,6 @@ public class CassandraMetadata implements ConnectorMetadata { private final String connectorId; - private final CachingCassandraSchemaProvider schemaProvider; private final CassandraSession cassandraSession; private final CassandraPartitionManager partitionManager; private final boolean allowDropTable; @@ -68,15 +73,14 @@ public class CassandraMetadata private final JsonCodec> extraColumnMetadataCodec; @Inject - public CassandraMetadata(CassandraConnectorId connectorId, - CachingCassandraSchemaProvider schemaProvider, + public CassandraMetadata( + CassandraConnectorId connectorId, CassandraSession cassandraSession, CassandraPartitionManager partitionManager, JsonCodec> extraColumnMetadataCodec, CassandraClientConfig config) { this.connectorId = requireNonNull(connectorId, "connectorId is null").toString(); - this.schemaProvider = requireNonNull(schemaProvider, "schemaProvider is null"); this.partitionManager = requireNonNull(partitionManager, "partitionManager is null"); this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); this.allowDropTable = requireNonNull(config, "config is null").getAllowDropTable(); @@ -86,7 +90,9 @@ public CassandraMetadata(CassandraConnectorId connectorId, @Override public List listSchemaNames(ConnectorSession session) { - return schemaProvider.getAllSchemas(); + return cassandraSession.getCaseSensitiveSchemaNames().stream() + .map(name -> name.toLowerCase(ENGLISH)) + .collect(toImmutableList()); } @Override @@ -94,11 +100,9 @@ public CassandraTableHandle getTableHandle(ConnectorSession session, SchemaTable { requireNonNull(tableName, "tableName is null"); try { - CassandraTableHandle tableHandle = schemaProvider.getTableHandle(tableName); - schemaProvider.getTable(tableHandle); - return tableHandle; + return cassandraSession.getTable(tableName).getTableHandle(); } - catch (NotFoundException e) { + catch (TableNotFoundException | SchemaNotFoundException e) { // table was not found return null; } @@ -113,15 +117,14 @@ private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { requireNonNull(tableHandle, "tableHandle is null"); - SchemaTableName tableName = getTableName(tableHandle); - return getTableMetadata(session, tableName); + return getTableMetadata(getTableName(tableHandle)); } - private ConnectorTableMetadata getTableMetadata(ConnectorSession session, SchemaTableName tableName) + private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) { - CassandraTableHandle tableHandle = schemaProvider.getTableHandle(tableName); - List columns = getColumnHandles(session, tableHandle).values().stream() - .map(column -> ((CassandraColumnHandle) column).getColumnMetadata()) + CassandraTable table = cassandraSession.getTable(tableName); + List columns = table.getColumns().stream() + .map(CassandraColumnHandle::getColumnMetadata) .collect(toList()); return new ConnectorTableMetadata(tableName, columns); } @@ -132,7 +135,7 @@ public List listTables(ConnectorSession session, String schemaN ImmutableList.Builder tableNames = ImmutableList.builder(); for (String schemaName : listSchemas(session, schemaNameOrNull)) { try { - for (String tableName : schemaProvider.getAllTables(schemaName)) { + for (String tableName : cassandraSession.getCaseSensitiveTableNames(schemaName)) { tableNames.add(new SchemaTableName(schemaName, tableName.toLowerCase(ENGLISH))); } } @@ -154,7 +157,9 @@ private List listSchemas(ConnectorSession session, String schemaNameOrNu @Override public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { - CassandraTable table = schemaProvider.getTable((CassandraTableHandle) tableHandle); + requireNonNull(session, "session is null"); + requireNonNull(tableHandle, "tableHandle is null"); + CassandraTable table = cassandraSession.getTable(getTableName(tableHandle)); ImmutableMap.Builder columnHandles = ImmutableMap.builder(); for (CassandraColumnHandle columnHandle : table.getColumns()) { columnHandles.put(CassandraCqlUtils.cqlNameToSqlName(columnHandle.getName()).toLowerCase(ENGLISH), columnHandle); @@ -169,7 +174,7 @@ public Map> listTableColumns(ConnectorSess ImmutableMap.Builder> columns = ImmutableMap.builder(); for (SchemaTableName tableName : listTables(session, prefix)) { try { - columns.put(tableName, getTableMetadata(session, tableName).getColumns()); + columns.put(tableName, getTableMetadata(tableName).getColumns()); } catch (NotFoundException e) { // table disappeared during listing operation @@ -198,16 +203,16 @@ public List getTableLayouts(ConnectorSession session CassandraTableHandle handle = (CassandraTableHandle) table; CassandraPartitionResult partitionResult = partitionManager.getPartitions(handle, constraint.getSummary()); - List clusteringKeyPredicates; + String clusteringKeyPredicates = ""; TupleDomain unenforcedConstraint; if (partitionResult.isUnpartitioned()) { - clusteringKeyPredicates = ImmutableList.of(); unenforcedConstraint = partitionResult.getUnenforcedConstraint(); } else { CassandraClusteringPredicatesExtractor clusteringPredicatesExtractor = new CassandraClusteringPredicatesExtractor( - schemaProvider.getTable(handle).getClusteringKeyColumns(), - partitionResult.getUnenforcedConstraint()); + cassandraSession.getTable(getTableName(handle)).getClusteringKeyColumns(), + partitionResult.getUnenforcedConstraint(), + cassandraSession.getCassandraVersion()); clusteringKeyPredicates = clusteringPredicatesExtractor.getClusteringKeyPredicates(); unenforcedConstraint = clusteringPredicatesExtractor.getUnenforcedConstraints(); } @@ -253,7 +258,6 @@ public void dropTable(ConnectorSession session, ConnectorTableHandle tableHandle String tableName = cassandraTableHandle.getTableName(); cassandraSession.execute(String.format("DROP TABLE \"%s\".\"%s\"", schemaName, tableName)); - schemaProvider.flushTable(cassandraTableHandle.getSchemaTableName()); } @Override @@ -277,7 +281,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con // get the root directory for the database SchemaTableName table = tableMetadata.getTable(); - String schemaName = schemaProvider.getCaseSensitiveSchemaName(table.getSchemaName()); + String schemaName = cassandraSession.getCaseSensitiveSchemaName(table.getSchemaName()); String tableName = table.getTableName(); List columns = columnNames.build(); List types = columnTypes.build(); @@ -309,8 +313,29 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con @Override public Optional finishCreateTable(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments) { - CassandraOutputTableHandle outputTableHandle = (CassandraOutputTableHandle) tableHandle; - schemaProvider.flushTable(new SchemaTableName(outputTableHandle.getSchemaName(), outputTableHandle.getTableName())); + return Optional.empty(); + } + + @Override + public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle) + { + CassandraTableHandle table = (CassandraTableHandle) tableHandle; + SchemaTableName schemaTableName = new SchemaTableName(table.getSchemaName(), table.getTableName()); + List columns = cassandraSession.getTable(schemaTableName).getColumns(); + List columnNames = columns.stream().map(CassandraColumnHandle::getName).map(CassandraCqlUtils::validColumnName).collect(Collectors.toList()); + List columnTypes = columns.stream().map(CassandraColumnHandle::getType).collect(Collectors.toList()); + + return new CassandraInsertTableHandle( + connectorId, + validSchemaName(table.getSchemaName()), + validTableName(table.getTableName()), + columnNames, + columnTypes); + } + + @Override + public Optional finishInsert(ConnectorSession session, ConnectorInsertTableHandle insertHandle, Collection fragments) + { return Optional.empty(); } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java index de561dd581180..dec232d7b4e65 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java @@ -42,13 +42,11 @@ public class CassandraPartitionManager { private static final Logger log = Logger.get(CassandraPartitionManager.class); - private final CachingCassandraSchemaProvider schemaProvider; private final CassandraSession cassandraSession; @Inject - public CassandraPartitionManager(CachingCassandraSchemaProvider schemaProvider, CassandraSession cassandraSession) + public CassandraPartitionManager(CassandraSession cassandraSession) { - this.schemaProvider = requireNonNull(schemaProvider, "schemaProvider is null"); this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); } @@ -56,7 +54,7 @@ public CassandraPartitionResult getPartitions(ConnectorTableHandle tableHandle, { CassandraTableHandle cassandraTableHandle = (CassandraTableHandle) tableHandle; - CassandraTable table = schemaProvider.getTable(cassandraTableHandle); + CassandraTable table = cassandraSession.getTable(cassandraTableHandle.getSchemaTableName()); List partitionKeys = table.getPartitionKeyColumns(); // fetch the partitions diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSink.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSink.java index 1ed416ead5f8a..52d26ba887c5c 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSink.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSink.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.cassandra; +import com.datastax.driver.core.PreparedStatement; +import com.datastax.driver.core.querybuilder.Insert; import com.facebook.presto.spi.RecordSink; import com.facebook.presto.spi.type.Type; import com.google.common.collect.ImmutableList; @@ -20,17 +22,24 @@ import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISODateTimeFormat; -import javax.inject.Inject; - +import java.nio.ByteBuffer; +import java.sql.Timestamp; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.UUID; import java.util.concurrent.TimeUnit; +import static com.datastax.driver.core.querybuilder.QueryBuilder.bindMarker; +import static com.datastax.driver.core.querybuilder.QueryBuilder.insertInto; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.RealType.REAL; +import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.spi.type.Varchars.isVarcharType; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.lang.Float.intBitsToFloat; import static java.nio.charset.StandardCharsets.UTF_8; @@ -41,37 +50,41 @@ public class CassandraRecordSink { private static final DateTimeFormatter DATE_FORMATTER = ISODateTimeFormat.date().withZoneUTC(); - private final int fieldCount; private final CassandraSession cassandraSession; - private final String insertQuery; + private final PreparedStatement insert; private final List values; private final List columnTypes; + private final boolean generateUUID; + private int field = -1; - @Inject - public CassandraRecordSink(CassandraOutputTableHandle handle, CassandraSession cassandraSession) + public CassandraRecordSink( + CassandraSession cassandraSession, + String schemaName, + String tableName, + List columnNames, + List columnTypes, + boolean generateUUID) { - this.fieldCount = requireNonNull(handle, "handle is null").getColumnNames().size(); - this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); - - String schemaName = handle.getSchemaName(); - StringBuilder queryBuilder = new StringBuilder(String.format("INSERT INTO \"%s\".\"%s\"(", schemaName, handle.getTableName())); - queryBuilder.append("id"); - - for (String columnName : handle.getColumnNames()) { - queryBuilder.append(",").append(columnName); + this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession"); + requireNonNull(schemaName, "schemaName is null"); + requireNonNull(tableName, "tableName is null"); + requireNonNull(columnNames, "columnNames is null"); + this.columnTypes = ImmutableList.copyOf(requireNonNull(columnTypes, "columnTypes is null")); + this.generateUUID = generateUUID; + + Insert insert = insertInto(schemaName, tableName); + if (generateUUID) { + insert.value("id", bindMarker()); } - queryBuilder.append(") VALUES (?"); - - for (int i = 0; i < handle.getColumnNames().size(); i++) { - queryBuilder.append(",?"); + for (int i = 0; i < columnNames.size(); i++) { + String columnName = columnNames.get(i); + checkArgument(columnName != null, "columnName is null at position: %d", i); + insert.value(columnName, bindMarker()); } - queryBuilder.append(")"); - - insertQuery = queryBuilder.toString(); - values = new ArrayList<>(); + this.insert = cassandraSession.prepare(insert); - columnTypes = handle.getColumnTypes(); + values = new ArrayList<>(columnTypes.size() + 1); } @Override @@ -81,16 +94,18 @@ public void beginRecord() field = 0; values.clear(); - values.add(UUID.randomUUID()); + if (generateUUID) { + values.add(UUID.randomUUID()); + } } @Override public void finishRecord() { checkState(field != -1, "not in record"); - checkState(field == fieldCount, "not all fields set"); + checkState(field == columnTypes.size(), "not all fields set"); field = -1; - cassandraSession.execute(insertQuery, values.toArray()); + cassandraSession.execute(insert.bind(values.toArray())); } @Override @@ -108,18 +123,25 @@ public void appendBoolean(boolean value) @Override public void appendLong(long value) { - if (DATE.equals(columnTypes.get(field))) { + Type columnType = columnTypes.get(field); + if (DATE.equals(columnType)) { append(DATE_FORMATTER.print(TimeUnit.DAYS.toMillis(value))); } - else if (INTEGER.equals(columnTypes.get(field))) { + else if (INTEGER.equals(columnType)) { append(((Number) value).intValue()); } - else if (REAL.equals(columnTypes.get(field))) { + else if (REAL.equals(columnType)) { append(intBitsToFloat((int) value)); } - else { + else if (TIMESTAMP.equals(columnType)) { + append(new Timestamp(value)); + } + else if (BIGINT.equals(columnType)) { append(value); } + else { + throw new UnsupportedOperationException("Type is not supported: " + columnType); + } } @Override @@ -131,7 +153,16 @@ public void appendDouble(double value) @Override public void appendString(byte[] value) { - append(new String(value, UTF_8)); + Type columnType = columnTypes.get(field); + if (VARBINARY.equals(columnType)) { + append(ByteBuffer.wrap(value)); + } + else if (isVarcharType(columnType)) { + append(new String(value, UTF_8)); + } + else { + throw new UnsupportedOperationException("Type is not supported: " + columnType); + } } @Override @@ -160,7 +191,7 @@ public List getColumnTypes() private void append(Object value) { checkState(field != -1, "not in record"); - checkState(field < fieldCount, "all fields already set"); + checkState(field < columnTypes.size(), "all fields already set"); values.add(value); field++; } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java index 11eb426a54caa..fa5d3d5274b30 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSession.java @@ -14,8 +14,12 @@ package com.facebook.presto.cassandra; import com.datastax.driver.core.Host; +import com.datastax.driver.core.PreparedStatement; +import com.datastax.driver.core.RegularStatement; import com.datastax.driver.core.ResultSet; -import com.datastax.driver.core.Session; +import com.datastax.driver.core.Statement; +import com.datastax.driver.core.TokenRange; +import com.datastax.driver.core.VersionNumber; import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; @@ -28,30 +32,33 @@ public interface CassandraSession { String PRESTO_COMMENT_METADATA = "Presto Metadata:"; - Set getReplicas(String schemaName, ByteBuffer partitionKey); + VersionNumber getCassandraVersion(); - List getAllSchemas(); + String getPartitioner(); - List getAllTables(String schema) - throws SchemaNotFoundException; + Set getTokenRanges(); + + Set getReplicas(String caseSensitiveSchemaName, TokenRange tokenRange); + + Set getReplicas(String caseSensitiveSchemaName, ByteBuffer partitionKey); - void getSchema(String schema) + String getCaseSensitiveSchemaName(String caseInsensitiveSchemaName); + + List getCaseSensitiveSchemaNames(); + + List getCaseSensitiveTableNames(String caseInsensitiveSchemaName) throws SchemaNotFoundException; - CassandraTable getTable(SchemaTableName tableName) + CassandraTable getTable(SchemaTableName schemaTableName) throws TableNotFoundException; List getPartitions(CassandraTable table, List filterPrefix); - default ResultSet execute(String cql, Object... values) - { - return executeWithSession(session -> session.execute(cql, values)); - } + ResultSet execute(String cql, Object... values); + + List getSizeEstimates(String keyspaceName, String tableName); - T executeWithSession(SessionCallable sessionCallable); + PreparedStatement prepare(RegularStatement statement); - interface SessionCallable - { - T executeWithSession(Session session); - } + ResultSet execute(Statement statement); } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java index e5d94b72c3c07..d0ce4dbbbb134 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java @@ -42,19 +42,17 @@ public class CassandraSplitManager { private final String connectorId; private final CassandraSession cassandraSession; - private final CachingCassandraSchemaProvider schemaProvider; private final int partitionSizeForBatchSelect; private final CassandraTokenSplitManager tokenSplitMgr; @Inject - public CassandraSplitManager(CassandraConnectorId connectorId, + public CassandraSplitManager( + CassandraConnectorId connectorId, CassandraClientConfig cassandraClientConfig, CassandraSession cassandraSession, - CachingCassandraSchemaProvider schemaProvider, CassandraTokenSplitManager tokenSplitMgr) { this.connectorId = requireNonNull(connectorId, "connectorId is null").toString(); - this.schemaProvider = requireNonNull(schemaProvider, "schemaProvider is null"); this.cassandraSession = requireNonNull(cassandraSession, "cassandraSession is null"); this.partitionSizeForBatchSelect = cassandraClientConfig.getPartitionSizeForBatchSelect(); this.tokenSplitMgr = tokenSplitMgr; @@ -75,7 +73,7 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, Co if (partitions.size() == 1) { CassandraPartition cassandraPartition = partitions.get(0); if (cassandraPartition.isUnpartitioned() || cassandraPartition.isIndexedColumnPredicatePushdown()) { - CassandraTable table = schemaProvider.getTable(cassandraTableHandle); + CassandraTable table = cassandraSession.getTable(cassandraTableHandle.getSchemaTableName()); List splits = getSplitsByTokenRange(table, cassandraPartition.getPartitionId()); return new FixedSplitSource(splits); } @@ -107,7 +105,7 @@ private static String buildTokenCondition(String tokenExpression, String startTo return tokenExpression + " > " + startToken + " AND " + tokenExpression + " <= " + endToken; } - private List getSplitsForPartitions(CassandraTableHandle cassTableHandle, List partitions, List clusteringPredicates) + private List getSplitsForPartitions(CassandraTableHandle cassTableHandle, List partitions, String clusteringPredicates) { String schema = cassTableHandle.getSchemaName(); HostAddressFactory hostAddressFactory = new HostAddressFactory(); @@ -150,7 +148,7 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab hostMap.put(hostAddresses, addresses); } else { - builder.addAll(createSplitsForClusteringPredicates(cassTableHandle, cassandraPartition.getPartitionId(), addresses, clusteringPredicates)); + builder.add(createSplitForClusteringPredicates(cassTableHandle, cassandraPartition.getPartitionId(), addresses, clusteringPredicates)); } } if (singlePartitionKeyColumn) { @@ -165,7 +163,7 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab size++; if (size > partitionSizeForBatchSelect) { String partitionId = String.format("%s in (%s)", partitionKeyColumnName, sb.toString()); - builder.addAll(createSplitsForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); + builder.add(createSplitForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); size = 0; sb.setLength(0); sb.trimToSize(); @@ -173,31 +171,27 @@ private List getSplitsForPartitions(CassandraTableHandle cassTab } if (size > 0) { String partitionId = String.format("%s in (%s)", partitionKeyColumnName, sb.toString()); - builder.addAll(createSplitsForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); + builder.add(createSplitForClusteringPredicates(cassTableHandle, partitionId, hostMap.get(entry.getKey()), clusteringPredicates)); } } } return builder.build(); } - private List createSplitsForClusteringPredicates( + private CassandraSplit createSplitForClusteringPredicates( CassandraTableHandle tableHandle, String partitionId, List hosts, - List clusteringPredicates) + String clusteringPredicates) { String schema = tableHandle.getSchemaName(); String table = tableHandle.getTableName(); if (clusteringPredicates.isEmpty()) { - return ImmutableList.of(new CassandraSplit(connectorId, schema, table, partitionId, null, hosts)); + return new CassandraSplit(connectorId, schema, table, partitionId, null, hosts); } - ImmutableList.Builder builder = ImmutableList.builder(); - for (String clusteringPredicate : clusteringPredicates) { - builder.add(new CassandraSplit(connectorId, schema, table, partitionId, clusteringPredicate, hosts)); - } - return builder.build(); + return new CassandraSplit(connectorId, schema, table, partitionId, clusteringPredicates, hosts); } @Override diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java index dfafda61ef6dc..fe2da9a38c013 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTableLayoutHandle.java @@ -28,19 +28,19 @@ public final class CassandraTableLayoutHandle { private final CassandraTableHandle table; private final List partitions; - private final List clusteringPredicates; + private final String clusteringPredicates; @JsonCreator public CassandraTableLayoutHandle(@JsonProperty("table") CassandraTableHandle table) { - this(table, ImmutableList.of(), ImmutableList.of()); + this(table, ImmutableList.of(), ""); } - public CassandraTableLayoutHandle(CassandraTableHandle table, List partitions, List clusteringPredicates) + public CassandraTableLayoutHandle(CassandraTableHandle table, List partitions, String clusteringPredicates) { this.table = requireNonNull(table, "table is null"); this.partitions = ImmutableList.copyOf(requireNonNull(partitions, "partition is null")); - this.clusteringPredicates = ImmutableList.copyOf(requireNonNull(clusteringPredicates, "clusteringPredicates is null")); + this.clusteringPredicates = requireNonNull(clusteringPredicates, "clusteringPredicates is null"); } @JsonProperty @@ -56,7 +56,7 @@ public List getPartitions() } @JsonIgnore - public List getClusteringPredicates() + public String getClusteringPredicates() { return clusteringPredicates; } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java index 9c55f7fcfee87..cc9fe2e400fb9 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java @@ -14,11 +14,6 @@ package com.facebook.presto.cassandra; import com.datastax.driver.core.Host; -import com.datastax.driver.core.KeyspaceMetadata; -import com.datastax.driver.core.ResultSet; -import com.datastax.driver.core.Row; -import com.datastax.driver.core.Statement; -import com.datastax.driver.core.TableMetadata; import com.datastax.driver.core.TokenRange; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; @@ -32,11 +27,8 @@ import java.util.Set; import java.util.concurrent.ThreadLocalRandom; -import static com.datastax.driver.core.querybuilder.QueryBuilder.eq; -import static com.datastax.driver.core.querybuilder.QueryBuilder.select; import static com.facebook.presto.cassandra.CassandraErrorCode.CASSANDRA_METADATA_ERROR; import static com.facebook.presto.cassandra.TokenRing.createForPartitioner; -import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.max; @@ -49,9 +41,6 @@ public class CassandraTokenSplitManager { - private static final String SYSTEM = "system"; - private static final String SIZE_ESTIMATES = "size_estimates"; - private final CassandraSession session; private final int splitSize; @@ -69,7 +58,7 @@ public CassandraTokenSplitManager(CassandraSession session, int splitSize) public List getSplits(String keyspace, String table) { - Set tokenRanges = getTokenRanges(); + Set tokenRanges = session.getTokenRanges(); if (tokenRanges.isEmpty()) { throw new PrestoException(CASSANDRA_METADATA_ERROR, "The cluster metadata is not available. " + @@ -81,7 +70,7 @@ public List getSplits(String keyspace, String table) tokenRanges = unwrap(tokenRanges); } - Optional tokenRing = createForPartitioner(getPartitioner()); + Optional tokenRing = createForPartitioner(session.getPartitioner()); long totalPartitionsCount = getTotalPartitionsCount(keyspace, table); List splits = new ArrayList<>(); @@ -118,11 +107,6 @@ public List getSplits(String keyspace, String table) return unmodifiableList(splits); } - private Set getTokenRanges() - { - return session.executeWithSession(session -> session.getCluster().getMetadata().getTokenRanges()); - } - private Set unwrap(Set tokenRanges) { ImmutableSet.Builder result = ImmutableSet.builder(); @@ -134,58 +118,20 @@ private Set unwrap(Set tokenRanges) private long getTotalPartitionsCount(String keyspace, String table) { - List estimates = getSizeEstimates(keyspace, table); + List estimates = session.getSizeEstimates(keyspace, table); return estimates.stream() .mapToLong(SizeEstimate::getPartitionsCount) .sum(); } - private List getSizeEstimates(String keyspaceName, String tableName) - { - checkSizeEstimatesTableExist(); - - Statement statement = select("range_start", "range_end", "mean_partition_size", "partitions_count") - .from(SYSTEM, SIZE_ESTIMATES) - .where(eq("keyspace_name", keyspaceName)) - .and(eq("table_name", tableName)); - - ResultSet result = session.executeWithSession(session -> session.execute(statement)); - ImmutableList.Builder estimates = ImmutableList.builder(); - for (Row row : result.all()) { - SizeEstimate estimate = new SizeEstimate( - row.getString("range_start"), - row.getString("range_end"), - row.getLong("mean_partition_size"), - row.getLong("partitions_count")); - estimates.add(estimate); - } - - return estimates.build(); - } - - private void checkSizeEstimatesTableExist() - { - KeyspaceMetadata ks = session.executeWithSession(session -> session.getCluster().getMetadata().getKeyspace(SYSTEM)); - checkState(ks != null, "system keyspace metadata must not be null"); - TableMetadata table = ks.getTable(SIZE_ESTIMATES); - if (table == null) { - throw new PrestoException(NOT_SUPPORTED, "Cassandra versions prior to 2.1.5 are not supported"); - } - } - private List getEndpoints(String keyspace, TokenRange tokenRange) { - Set endpoints = session.executeWithSession(session -> session.getCluster().getMetadata().getReplicas(keyspace, tokenRange)); + Set endpoints = session.getReplicas(keyspace, tokenRange); return unmodifiableList(endpoints.stream() .map(Host::toString) .collect(toList())); } - private String getPartitioner() - { - return session.executeWithSession(session -> session.getCluster().getMetadata().getPartitioner()); - } - private static TokenSplit createSplit(TokenRange range, List endpoints) { checkArgument(!range.isEmpty(), "tokenRange must not be empty"); @@ -222,40 +168,4 @@ public List getHosts() return hosts; } } - - private static class SizeEstimate - { - private final String rangeStart; - private final String rangeEnd; - private final long meanPartitionSize; - private final long partitionsCount; - - public SizeEstimate(String rangeStart, String rangeEnd, long meanPartitionSize, long partitionsCount) - { - this.rangeStart = requireNonNull(rangeStart, "rangeStart is null"); - this.rangeEnd = requireNonNull(rangeEnd, "rangeEnd is null"); - this.meanPartitionSize = meanPartitionSize; - this.partitionsCount = partitionsCount; - } - - public String getRangeStart() - { - return rangeStart; - } - - public String getRangeEnd() - { - return rangeEnd; - } - - public long getMeanPartitionSize() - { - return meanPartitionSize; - } - - public long getPartitionsCount() - { - return partitionsCount; - } - } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java index d46b0a79f2135..bd08ff21fdbfb 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java @@ -19,9 +19,15 @@ import com.datastax.driver.core.Host; import com.datastax.driver.core.IndexMetadata; import com.datastax.driver.core.KeyspaceMetadata; +import com.datastax.driver.core.PreparedStatement; +import com.datastax.driver.core.RegularStatement; +import com.datastax.driver.core.ResultSet; import com.datastax.driver.core.Row; import com.datastax.driver.core.Session; +import com.datastax.driver.core.Statement; import com.datastax.driver.core.TableMetadata; +import com.datastax.driver.core.TokenRange; +import com.datastax.driver.core.VersionNumber; import com.datastax.driver.core.exceptions.NoHostAvailableException; import com.datastax.driver.core.policies.ReconnectionPolicy; import com.datastax.driver.core.policies.ReconnectionPolicy.ReconnectionSchedule; @@ -30,6 +36,7 @@ import com.datastax.driver.core.querybuilder.Select; import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; @@ -46,29 +53,41 @@ import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; +import static com.datastax.driver.core.querybuilder.QueryBuilder.eq; +import static com.datastax.driver.core.querybuilder.QueryBuilder.select; import static com.datastax.driver.core.querybuilder.Select.Where; +import static com.facebook.presto.cassandra.CassandraErrorCode.CASSANDRA_VERSION_ERROR; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validSchemaName; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.in; import static com.google.common.base.Predicates.not; import static com.google.common.base.Suppliers.memoize; import static com.google.common.collect.Iterables.filter; import static com.google.common.collect.Iterables.transform; +import static java.lang.String.format; import static java.util.Comparator.comparing; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; -// TODO: Refactor this class to make it be "single responsibility" public class NativeCassandraSession implements CassandraSession { private static final Logger log = Logger.get(NativeCassandraSession.class); - static final String PRESTO_COMMENT_METADATA = "Presto Metadata:"; - protected final String connectorId; + private static final String PRESTO_COMMENT_METADATA = "Presto Metadata:"; + private static final String SYSTEM = "system"; + private static final String SIZE_ESTIMATES = "size_estimates"; + + private final String connectorId; private final JsonCodec> extraColumnMetadataCodec; private final Cluster cluster; private final Supplier session; @@ -84,13 +103,56 @@ public NativeCassandraSession(String connectorId, JsonCodec getReplicas(String schemaName, ByteBuffer partitionKey) + public VersionNumber getCassandraVersion() + { + ResultSet result = executeWithSession(session -> session.execute("select release_version from system.local")); + Row versionRow = result.one(); + if (versionRow == null) { + throw new PrestoException(CASSANDRA_VERSION_ERROR, "The cluster version is not available. " + + "Please make sure that the Cassandra cluster is up and running, " + + "and that the contact points are specified correctly."); + } + return VersionNumber.parse(versionRow.getString("release_version")); + } + + @Override + public String getPartitioner() + { + return executeWithSession(session -> session.getCluster().getMetadata().getPartitioner()); + } + + @Override + public Set getTokenRanges() + { + return executeWithSession(session -> session.getCluster().getMetadata().getTokenRanges()); + } + + @Override + public Set getReplicas(String caseSensitiveSchemaName, TokenRange tokenRange) + { + requireNonNull(caseSensitiveSchemaName, "keyspace is null"); + requireNonNull(tokenRange, "tokenRange is null"); + return executeWithSession(session -> + session.getCluster().getMetadata().getReplicas(validSchemaName(caseSensitiveSchemaName), tokenRange)); + } + + @Override + public Set getReplicas(String caseSensitiveSchemaName, ByteBuffer partitionKey) + { + requireNonNull(caseSensitiveSchemaName, "keyspace is null"); + requireNonNull(partitionKey, "partitionKey is null"); + return executeWithSession(session -> + session.getCluster().getMetadata().getReplicas(validSchemaName(caseSensitiveSchemaName), partitionKey)); + } + + @Override + public String getCaseSensitiveSchemaName(String caseInsensitiveSchemaName) { - return executeWithSession(session -> session.getCluster().getMetadata().getReplicas(schemaName, partitionKey)); + return getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName).getName(); } @Override - public List getAllSchemas() + public List getCaseSensitiveSchemaNames() { ImmutableList.Builder builder = ImmutableList.builder(); List keyspaces = executeWithSession(session -> session.getCluster().getMetadata().getKeyspaces()); @@ -101,42 +163,28 @@ public List getAllSchemas() } @Override - public List getAllTables(String schema) + public List getCaseSensitiveTableNames(String caseInsensitiveSchemaName) throws SchemaNotFoundException { - KeyspaceMetadata meta = getCheckedKeyspaceMetadata(schema); + KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName); ImmutableList.Builder builder = ImmutableList.builder(); - for (TableMetadata tableMeta : meta.getTables()) { - builder.add(tableMeta.getName()); + for (TableMetadata table : keyspace.getTables()) { + builder.add(table.getName()); } return builder.build(); } - private KeyspaceMetadata getCheckedKeyspaceMetadata(String schema) - throws SchemaNotFoundException - { - KeyspaceMetadata keyspaceMetadata = executeWithSession(session -> session.getCluster().getMetadata().getKeyspace(schema)); - if (keyspaceMetadata == null) { - throw new SchemaNotFoundException(schema); - } - return keyspaceMetadata; - } - - @Override - public void getSchema(String schema) - throws SchemaNotFoundException - { - getCheckedKeyspaceMetadata(schema); - } - @Override - public CassandraTable getTable(SchemaTableName tableName) + public CassandraTable getTable(SchemaTableName schemaTableName) throws TableNotFoundException { - TableMetadata tableMeta = getTableMetadata(tableName); + KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(schemaTableName.getSchemaName()); + TableMetadata tableMeta = getTableMetadata(keyspace, schemaTableName.getTableName()); List columnNames = new ArrayList<>(); - for (ColumnMetadata columnMetadata : tableMeta.getColumns()) { + List columns = tableMeta.getColumns(); + checkColumnNames(columns); + for (ColumnMetadata columnMetadata : columns) { columnNames.add(columnMetadata.getName()); } @@ -178,7 +226,7 @@ public CassandraTable getTable(SchemaTableName tableName) } // add other columns - for (ColumnMetadata columnMeta : tableMeta.getColumns()) { + for (ColumnMetadata columnMeta : columns) { if (!primaryKeySet.contains(columnMeta.getName())) { boolean hidden = hiddenColumns.contains(columnMeta.getName()); CassandraColumnHandle columnHandle = buildColumnHandle(tableMeta, columnMeta, false, false, columnNames.indexOf(columnMeta.getName()), hidden); @@ -194,23 +242,66 @@ public CassandraTable getTable(SchemaTableName tableName) return new CassandraTable(tableHandle, sortedColumnHandles); } - private TableMetadata getTableMetadata(SchemaTableName schemaTableName) + private KeyspaceMetadata getKeyspaceByCaseInsensitiveName(String caseInsensitiveSchemaName) + throws SchemaNotFoundException { - String schemaName = schemaTableName.getSchemaName(); - String tableName = schemaTableName.getTableName(); + List keyspaces = executeWithSession(session -> session.getCluster().getMetadata().getKeyspaces()); + KeyspaceMetadata result = null; + // Ensure that the error message is deterministic + List sortedKeyspaces = Ordering.from(comparing(KeyspaceMetadata::getName)).immutableSortedCopy(keyspaces); + for (KeyspaceMetadata keyspace : sortedKeyspaces) { + if (keyspace.getName().equalsIgnoreCase(caseInsensitiveSchemaName)) { + if (result != null) { + throw new PrestoException( + NOT_SUPPORTED, + format("More than one keyspace has been found for the case insensitive schema name: %s -> (%s, %s)", + caseInsensitiveSchemaName, result.getName(), keyspace.getName())); + } + result = keyspace; + } + } + if (result == null) { + throw new SchemaNotFoundException(caseInsensitiveSchemaName); + } + return result; + } - KeyspaceMetadata keyspaceMetadata = getCheckedKeyspaceMetadata(schemaName); - TableMetadata tableMetadata = keyspaceMetadata.getTable(tableName); - if (tableMetadata != null) { - return tableMetadata; + private static TableMetadata getTableMetadata(KeyspaceMetadata keyspace, String caseInsensitiveTableName) + { + TableMetadata result = null; + Collection tables = keyspace.getTables(); + // Ensure that the error message is deterministic + List sortedTables = Ordering.from(comparing(TableMetadata::getName)).immutableSortedCopy(tables); + for (TableMetadata table : sortedTables) { + if (table.getName().equalsIgnoreCase(caseInsensitiveTableName)) { + if (result != null) { + throw new PrestoException( + NOT_SUPPORTED, + format("More than one table has been found for the case insensitive table name: %s -> (%s, %s)", + caseInsensitiveTableName, result.getName(), table.getName())); + } + result = table; + } } + if (result == null) { + throw new TableNotFoundException(new SchemaTableName(keyspace.getName(), caseInsensitiveTableName)); + } + return result; + } - for (TableMetadata table : keyspaceMetadata.getTables()) { - if (table.getName().equalsIgnoreCase(tableName)) { - return table; + private static void checkColumnNames(List columns) + { + Map lowercaseNameToColumnMap = new HashMap<>(); + for (ColumnMetadata column : columns) { + String lowercaseName = column.getName().toLowerCase(ENGLISH); + if (lowercaseNameToColumnMap.containsKey(lowercaseName)) { + throw new PrestoException( + NOT_SUPPORTED, + format("More than one column has been found for the case insensitive column name: %s -> (%s, %s)", + lowercaseName, lowercaseNameToColumnMap.get(lowercaseName).getName(), column.getName())); } + lowercaseNameToColumnMap.put(lowercaseName, column); } - throw new TableNotFoundException(schemaTableName); } private CassandraColumnHandle buildColumnHandle(TableMetadata tableMetadata, ColumnMetadata columnMeta, boolean partitionKey, boolean clusteringKey, int ordinalPosition, boolean hidden) @@ -297,7 +388,25 @@ public List getPartitions(CassandraTable table, List return partitions.build(); } - protected Iterable queryPartitionKeys(CassandraTable table, List filterPrefix) + @Override + public ResultSet execute(String cql, Object... values) + { + return executeWithSession(session -> session.execute(cql, values)); + } + + @Override + public PreparedStatement prepare(RegularStatement statement) + { + return executeWithSession(session -> session.prepare(statement)); + } + + @Override + public ResultSet execute(Statement statement) + { + return executeWithSession(session -> session.execute(statement)); + } + + private Iterable queryPartitionKeys(CassandraTable table, List filterPrefix) { CassandraTableHandle tableHandle = table.getTableHandle(); List partitionKeyColumns = table.getPartitionKeyColumns(); @@ -308,11 +417,53 @@ protected Iterable queryPartitionKeys(CassandraTable table, List fi Select partitionKeys = CassandraCqlUtils.selectDistinctFrom(tableHandle, partitionKeyColumns); addWhereClause(partitionKeys.where(), partitionKeyColumns, filterPrefix); - return executeWithSession(session -> session.execute(partitionKeys)).all(); + return execute(partitionKeys).all(); + } + + private static void addWhereClause(Where where, List partitionKeyColumns, List filterPrefix) + { + for (int i = 0; i < filterPrefix.size(); i++) { + CassandraColumnHandle column = partitionKeyColumns.get(i); + Object value = column.getCassandraType().getJavaValue(filterPrefix.get(i)); + Clause clause = QueryBuilder.eq(CassandraCqlUtils.validColumnName(column.getName()), value); + where.and(clause); + } } @Override - public T executeWithSession(SessionCallable sessionCallable) + public List getSizeEstimates(String keyspaceName, String tableName) + { + checkSizeEstimatesTableExist(); + Statement statement = select("range_start", "range_end", "mean_partition_size", "partitions_count") + .from(SYSTEM, SIZE_ESTIMATES) + .where(eq("keyspace_name", keyspaceName)) + .and(eq("table_name", tableName)); + + ResultSet result = executeWithSession(session -> session.execute(statement)); + ImmutableList.Builder estimates = ImmutableList.builder(); + for (Row row : result.all()) { + SizeEstimate estimate = new SizeEstimate( + row.getString("range_start"), + row.getString("range_end"), + row.getLong("mean_partition_size"), + row.getLong("partitions_count")); + estimates.add(estimate); + } + + return estimates.build(); + } + + private void checkSizeEstimatesTableExist() + { + KeyspaceMetadata keyspaceMetadata = executeWithSession(session -> session.getCluster().getMetadata().getKeyspace(SYSTEM)); + checkState(keyspaceMetadata != null, "system keyspace metadata must not be null"); + TableMetadata table = keyspaceMetadata.getTable(SIZE_ESTIMATES); + if (table == null) { + throw new PrestoException(NOT_SUPPORTED, "Cassandra versions prior to 2.1.5 are not supported"); + } + } + + private T executeWithSession(SessionCallable sessionCallable) { ReconnectionPolicy reconnectionPolicy = cluster.getConfiguration().getPolicies().getReconnectionPolicy(); ReconnectionSchedule schedule = reconnectionPolicy.newSchedule(); @@ -342,13 +493,8 @@ public T executeWithSession(SessionCallable sessionCallable) } } - private static void addWhereClause(Where where, List partitionKeyColumns, List filterPrefix) + private interface SessionCallable { - for (int i = 0; i < filterPrefix.size(); i++) { - CassandraColumnHandle column = partitionKeyColumns.get(i); - Object value = column.getCassandraType().getJavaValue(filterPrefix.get(i)); - Clause clause = QueryBuilder.eq(CassandraCqlUtils.validColumnName(column.getName()), value); - where.and(clause); - } + T executeWithSession(Session session); } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java new file mode 100644 index 0000000000000..a8c988b460ed4 --- /dev/null +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cassandra; + +import com.datastax.driver.core.CloseFuture; +import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.DelegatingCluster; +import io.airlift.log.Logger; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public class ReopeningCluster + extends DelegatingCluster +{ + private static final Logger log = Logger.get(ReopeningCluster.class); + + @GuardedBy("this") + private Cluster delegate; + @GuardedBy("this") + private boolean closed; + + private final Supplier supplier; + + public ReopeningCluster(Supplier supplier) + { + this.supplier = requireNonNull(supplier, "supplier is null"); + } + + @Override + protected synchronized Cluster delegate() + { + checkState(!closed, "Cluster has been closed"); + + if (delegate == null) { + delegate = supplier.get(); + } + + if (delegate.isClosed()) { + log.warn("Cluster has been closed internally"); + delegate = supplier.get(); + } + + verify(!delegate.isClosed(), "Newly created cluster has been immediately closed"); + + return delegate; + } + + @Override + public synchronized void close() + { + closed = true; + if (delegate != null) { + delegate.close(); + delegate = null; + } + } + + @Override + public synchronized boolean isClosed() + { + return closed; + } + + @Override + public synchronized CloseFuture closeAsync() + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/SizeEstimate.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/SizeEstimate.java new file mode 100644 index 0000000000000..5779f9ffe18dd --- /dev/null +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/SizeEstimate.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cassandra; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class SizeEstimate +{ + private final String rangeStart; + private final String rangeEnd; + private final long meanPartitionSize; + private final long partitionsCount; + + public SizeEstimate(String rangeStart, String rangeEnd, long meanPartitionSize, long partitionsCount) + { + this.rangeStart = requireNonNull(rangeStart, "rangeStart is null"); + this.rangeEnd = requireNonNull(rangeEnd, "rangeEnd is null"); + this.meanPartitionSize = meanPartitionSize; + this.partitionsCount = partitionsCount; + } + + public String getRangeStart() + { + return rangeStart; + } + + public String getRangeEnd() + { + return rangeEnd; + } + + public long getMeanPartitionSize() + { + return meanPartitionSize; + } + + public long getPartitionsCount() + { + return partitionsCount; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SizeEstimate that = (SizeEstimate) o; + return meanPartitionSize == that.meanPartitionSize && + partitionsCount == that.partitionsCount && + Objects.equals(rangeStart, that.rangeStart) && + Objects.equals(rangeEnd, that.rangeEnd); + } + + @Override + public int hashCode() + { + return Objects.hash(rangeStart, rangeEnd, meanPartitionSize, partitionsCount); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("rangeStart", rangeStart) + .add("rangeEnd", rangeEnd) + .add("meanPartitionSize", meanPartitionSize) + .add("partitionsCount", partitionsCount) + .toString(); + } +} diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java index 0f95f82a1f429..8dbba13e5b1c0 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java @@ -55,9 +55,8 @@ public static synchronized DistributedQueryRunner createCassandraQueryRunner() List> tables = TpchTable.getTables(); copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, createCassandraSession("tpch"), tables); for (TpchTable table : tables) { - EmbeddedCassandra.flush("tpch", table.getTableName()); + EmbeddedCassandra.refreshSizeEstimates("tpch", table.getTableName()); } - EmbeddedCassandra.refreshSizeEstimates(); tpchLoaded = true; } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java index 936f84946e1f1..26334b8434fca 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java @@ -33,21 +33,25 @@ public class CassandraTestingUtils { public static final String TABLE_ALL_TYPES = "table_all_types"; + public static final String TABLE_ALL_TYPES_INSERT = "table_all_types_insert"; public static final String TABLE_ALL_TYPES_PARTITION_KEY = "table_all_types_partition_key"; public static final String TABLE_CLUSTERING_KEYS = "table_clustering_keys"; public static final String TABLE_CLUSTERING_KEYS_LARGE = "table_clustering_keys_large"; public static final String TABLE_MULTI_PARTITION_CLUSTERING_KEYS = "table_multi_partition_clustering_keys"; + public static final String TABLE_CLUSTERING_KEYS_INEQUALITY = "table_clustering_keys_inequality"; private CassandraTestingUtils() {} public static void createTestTables(CassandraSession cassandraSession, String keyspace, Date date) { createKeyspace(cassandraSession, keyspace); - createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES), date); + createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES), date, 9); + createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES_INSERT), date, 0); createTableAllTypesPartitionKey(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES_PARTITION_KEY), date); createTableClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS), 9); createTableClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS_LARGE), 1000); createTableMultiPartitionClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_MULTI_PARTITION_CLUSTERING_KEYS)); + createTableClusteringKeysInequality(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS_INEQUALITY), date, 4); } public static void createKeyspace(CassandraSession session, String keyspaceName) @@ -77,7 +81,7 @@ public static void insertIntoTableClusteringKeys(CassandraSession session, Schem .value("clust_one", "clust_one") .value("clust_two", "clust_two_" + rowNumber.toString()) .value("clust_three", "clust_three_" + rowNumber.toString()); - session.executeWithSession(s -> s.execute(insert)); + session.execute(insert); } assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), rowsCount); } @@ -106,12 +110,39 @@ public static void insertIntoTableMultiPartitionClusteringKeys(CassandraSession .value("clust_one", "clust_one") .value("clust_two", "clust_two_" + rowNumber.toString()) .value("clust_three", "clust_three_" + rowNumber.toString()); - session.executeWithSession(s -> s.execute(insert)); + session.execute(insert); } assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), 9); } - public static void createTableAllTypes(CassandraSession session, SchemaTableName table, Date date) + public static void createTableClusteringKeysInequality(CassandraSession session, SchemaTableName table, Date date, int rowsCount) + { + session.execute("DROP TABLE IF EXISTS " + table); + session.execute("CREATE TABLE " + table + " (" + + "key text, " + + "clust_one text, " + + "clust_two int, " + + "clust_three timestamp, " + + "data text, " + + "PRIMARY KEY((key), clust_one, clust_two, clust_three) " + + ")"); + insertIntoTableClusteringKeysInequality(session, table, date, rowsCount); + } + + public static void insertIntoTableClusteringKeysInequality(CassandraSession session, SchemaTableName table, Date date, int rowsCount) + { + for (Integer rowNumber = 1; rowNumber <= rowsCount; rowNumber++) { + Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) + .value("key", "key_1") + .value("clust_one", "clust_one") + .value("clust_two", rowNumber) + .value("clust_three", date.getTime() + rowNumber * 10); + session.execute(insert); + } + assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), rowsCount); + } + + public static void createTableAllTypes(CassandraSession session, SchemaTableName table, Date date, int rowsCount) { session.execute("DROP TABLE IF EXISTS " + table); session.execute("CREATE TABLE " + table + " (" + @@ -134,7 +165,7 @@ public static void createTableAllTypes(CassandraSession session, SchemaTableName " typemap map, " + " typeset set, " + ")"); - insertTestData(session, table, date); + insertTestData(session, table, date, rowsCount); } public static void createTableAllTypesPartitionKey(CassandraSession session, SchemaTableName table, Date date) @@ -186,12 +217,12 @@ public static void createTableAllTypesPartitionKey(CassandraSession session, Sch " ))" + ")"); - insertTestData(session, table, date); + insertTestData(session, table, date, 9); } - private static void insertTestData(CassandraSession session, SchemaTableName table, Date date) + private static void insertTestData(CassandraSession session, SchemaTableName table, Date date, int rowsCount) { - for (Integer rowNumber = 1; rowNumber < 10; rowNumber++) { + for (Integer rowNumber = 1; rowNumber <= rowsCount; rowNumber++) { Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) .value("key", "key " + rowNumber.toString()) .value("typeuuid", UUID.fromString(String.format("00000000-0000-0000-0000-%012d", rowNumber))) @@ -212,8 +243,8 @@ private static void insertTestData(CassandraSession session, SchemaTableName tab .value("typemap", ImmutableMap.of(rowNumber, rowNumber + 1L, rowNumber + 2, rowNumber + 3L)) .value("typeset", ImmutableSet.of(false, true)); - session.executeWithSession(s -> s.execute(insert)); + session.execute(insert); } - assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), 9); + assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), rowsCount); } } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java index 0b3c4c3bd7cb0..dd6b0f9e9f61b 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/EmbeddedCassandra.java @@ -32,15 +32,18 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.List; +import java.util.concurrent.TimeoutException; import static com.datastax.driver.core.ProtocolVersion.V3; import static com.google.common.base.Preconditions.checkState; import static com.google.common.io.Files.createTempDir; import static com.google.common.io.Files.write; import static com.google.common.io.Resources.getResource; +import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.testng.Assert.assertEquals; public final class EmbeddedCassandra @@ -50,6 +53,8 @@ public final class EmbeddedCassandra private static final String HOST = "127.0.0.1"; private static final int PORT = 9142; + private static final Duration REFRESH_SIZE_ESTIMATES_TIMEOUT = new Duration(1, MINUTES); + private static CassandraSession session; private static boolean initialized; @@ -76,6 +81,7 @@ public static synchronized void start() .withClusterName("TestCluster") .addContactPointsWithPorts(ImmutableList.of( new InetSocketAddress(HOST, PORT))) + .withMaxSchemaAgreementWaitSeconds(30) .build(); CassandraSession session = new NativeCassandraSession( @@ -148,7 +154,24 @@ private static void checkConnectivity(CassandraSession session) log.info("Cassandra version: %s", version); } - public static void flush(String keyspace, String table) + public static void refreshSizeEstimates(String keyspace, String table) + throws Exception + { + long deadline = System.nanoTime() + REFRESH_SIZE_ESTIMATES_TIMEOUT.roundTo(NANOSECONDS); + while (System.nanoTime() < deadline) { + flushTable(keyspace, table); + refreshSizeEstimates(); + List sizeEstimates = getSession().getSizeEstimates(keyspace, table); + if (!sizeEstimates.isEmpty()) { + log.info("Size estimates for the table %s.%s have been refreshed successfully: %s", keyspace, table, sizeEstimates); + return; + } + log.info("Size estimates haven't been refreshed as expected. Retrying ..."); + } + throw new TimeoutException(format("Attempting to refresh size estimates for table %s.%s has timed out after %s", keyspace, table, REFRESH_SIZE_ESTIMATES_TIMEOUT)); + } + + private static void flushTable(String keyspace, String table) throws Exception { ManagementFactory @@ -160,7 +183,7 @@ public static void flush(String keyspace, String table) new String[] {"java.lang.String", "[Ljava.lang.String;"}); } - public static void refreshSizeEstimates() + private static void refreshSizeEstimates() throws Exception { ManagementFactory diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/MockCassandraSession.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/MockCassandraSession.java deleted file mode 100644 index 3e4365d3f2bee..0000000000000 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/MockCassandraSession.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cassandra; - -import com.datastax.driver.core.Host; -import com.facebook.presto.spi.SchemaNotFoundException; -import com.facebook.presto.spi.SchemaTableName; -import com.facebook.presto.spi.TableNotFoundException; -import com.google.common.collect.ImmutableList; - -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; - -import static java.util.Objects.requireNonNull; - -public class MockCassandraSession - implements CassandraSession -{ - static final String TEST_SCHEMA = "testkeyspace"; - static final String BAD_SCHEMA = "badkeyspace"; - static final String TEST_TABLE = "testtbl"; - static final String TEST_COLUMN1 = "column1"; - static final String TEST_COLUMN2 = "column2"; - - private final AtomicInteger accessCount = new AtomicInteger(); - private boolean throwException; - private final String connectorId; - - public MockCassandraSession(String connectorId) - { - this.connectorId = requireNonNull(connectorId, "connectorId is null"); - } - - public void setThrowException(boolean throwException) - { - this.throwException = throwException; - } - - public int getAccessCount() - { - return accessCount.get(); - } - - @Override - public List getAllSchemas() - { - accessCount.incrementAndGet(); - - if (throwException) { - throw new IllegalStateException(); - } - return ImmutableList.of(TEST_SCHEMA); - } - - @Override - public List getAllTables(String schema) - throws SchemaNotFoundException - { - accessCount.incrementAndGet(); - if (throwException) { - throw new IllegalStateException(); - } - - if (schema.equals(TEST_SCHEMA)) { - return ImmutableList.of(TEST_TABLE); - } - throw new SchemaNotFoundException(schema); - } - - @Override - public void getSchema(String schema) - throws SchemaNotFoundException - { - accessCount.incrementAndGet(); - if (throwException) { - throw new IllegalStateException(); - } - - if (!schema.equals(TEST_SCHEMA)) { - throw new SchemaNotFoundException(schema); - } - } - - @Override - public CassandraTable getTable(SchemaTableName tableName) - throws TableNotFoundException - { - accessCount.incrementAndGet(); - if (throwException) { - throw new IllegalStateException(); - } - - if (tableName.getSchemaName().equals(TEST_SCHEMA) && tableName.getTableName().equals(TEST_TABLE)) { - return new CassandraTable( - new CassandraTableHandle(connectorId, TEST_SCHEMA, TEST_TABLE), - ImmutableList.of( - new CassandraColumnHandle(connectorId, TEST_COLUMN1, 0, CassandraType.VARCHAR, null, true, false, false, false), - new CassandraColumnHandle(connectorId, TEST_COLUMN2, 0, CassandraType.INT, null, false, false, false, false))); - } - throw new TableNotFoundException(tableName); - } - - @Override - public List getPartitions(CassandraTable table, List filterPrefix) - { - throw new UnsupportedOperationException(); - } - - @Override - public T executeWithSession(SessionCallable sessionCallable) - { - throw new UnsupportedOperationException(); - } - - @Override - public Set getReplicas(String schemaName, ByteBuffer partitionKey) - { - throw new UnsupportedOperationException(); - } -} diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCachingCassandraSchemaProvider.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCachingCassandraSchemaProvider.java deleted file mode 100644 index 1f1c8c81b3004..0000000000000 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCachingCassandraSchemaProvider.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cassandra; - -import com.facebook.presto.spi.SchemaNotFoundException; -import com.facebook.presto.spi.TableNotFoundException; -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.Duration; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; - -import java.util.concurrent.TimeUnit; - -import static com.facebook.presto.cassandra.MockCassandraSession.BAD_SCHEMA; -import static com.facebook.presto.cassandra.MockCassandraSession.TEST_SCHEMA; -import static com.facebook.presto.cassandra.MockCassandraSession.TEST_TABLE; -import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; -import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static java.util.concurrent.Executors.newCachedThreadPool; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertNotNull; - -@Test(singleThreaded = true) -public class TestCachingCassandraSchemaProvider -{ - private static final String CONNECTOR_ID = "test-cassandra"; - private MockCassandraSession mockSession; - private CachingCassandraSchemaProvider schemaProvider; - - @BeforeMethod - public void setUp() - throws Exception - { - mockSession = new MockCassandraSession(CONNECTOR_ID); - ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("test-%s"))); - schemaProvider = new CachingCassandraSchemaProvider( - CONNECTOR_ID, - mockSession, - executor, - new Duration(5, TimeUnit.MINUTES), - new Duration(1, TimeUnit.MINUTES)); - } - - @Test - public void testGetAllDatabases() - throws Exception - { - assertEquals(mockSession.getAccessCount(), 0); - assertEquals(schemaProvider.getAllSchemas(), ImmutableList.of(TEST_SCHEMA)); - assertEquals(mockSession.getAccessCount(), 1); - assertEquals(schemaProvider.getAllSchemas(), ImmutableList.of(TEST_SCHEMA)); - assertEquals(mockSession.getAccessCount(), 1); - - schemaProvider.flushCache(); - - assertEquals(schemaProvider.getAllSchemas(), ImmutableList.of(TEST_SCHEMA)); - assertEquals(mockSession.getAccessCount(), 2); - } - - @Test - public void testGetAllTable() - throws Exception - { - assertEquals(mockSession.getAccessCount(), 0); - assertEquals(schemaProvider.getAllTables(TEST_SCHEMA), ImmutableList.of(TEST_TABLE)); - assertEquals(mockSession.getAccessCount(), 2); - assertEquals(schemaProvider.getAllTables(TEST_SCHEMA), ImmutableList.of(TEST_TABLE)); - assertEquals(mockSession.getAccessCount(), 2); - - schemaProvider.flushCache(); - - assertEquals(schemaProvider.getAllTables(TEST_SCHEMA), ImmutableList.of(TEST_TABLE)); - assertEquals(mockSession.getAccessCount(), 4); - } - - @Test(expectedExceptions = SchemaNotFoundException.class) - public void testInvalidDbGetAllTAbles() - throws Exception - { - schemaProvider.getAllTables(BAD_SCHEMA); - } - - @Test - public void testGetTable() - throws Exception - { - CassandraTableHandle tableHandle = new CassandraTableHandle(CONNECTOR_ID, TEST_SCHEMA, TEST_TABLE); - assertEquals(mockSession.getAccessCount(), 0); - assertNotNull(schemaProvider.getTable(tableHandle)); - assertEquals(mockSession.getAccessCount(), 1); - assertNotNull(schemaProvider.getTable(tableHandle)); - assertEquals(mockSession.getAccessCount(), 1); - - schemaProvider.flushCache(); - - assertNotNull(schemaProvider.getTable(tableHandle)); - assertEquals(mockSession.getAccessCount(), 2); - } - - @Test(expectedExceptions = TableNotFoundException.class) - public void testInvalidDbGetTable() - throws Exception - { - CassandraTableHandle tableHandle = new CassandraTableHandle(CONNECTOR_ID, BAD_SCHEMA, TEST_TABLE); - schemaProvider.getTable(tableHandle); - } - - @Test - public void testNoCacheExceptions() - throws Exception - { - // Throw exceptions on usage - mockSession.setThrowException(true); - try { - schemaProvider.getAllSchemas(); - } - catch (RuntimeException ignored) { - } - assertEquals(mockSession.getAccessCount(), 1); - - // Second try should hit the client again - try { - schemaProvider.getAllSchemas(); - } - catch (RuntimeException ignored) { - } - assertEquals(mockSession.getAccessCount(), 2); - } -} diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java index 28da5db5591f4..7d9503f355b10 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java @@ -21,7 +21,6 @@ import org.testng.annotations.Test; import java.util.Map; -import java.util.concurrent.TimeUnit; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; @@ -33,9 +32,6 @@ public class TestCassandraClientConfig public void testDefaults() { ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(CassandraClientConfig.class) - .setMaxSchemaRefreshThreads(1) - .setSchemaCacheTtl(new Duration(1, TimeUnit.HOURS)) - .setSchemaRefreshInterval(new Duration(2, TimeUnit.MINUTES)) .setFetchSize(5_000) .setConsistencyLevel(ConsistencyLevel.ONE) .setContactPoints("") @@ -66,9 +62,6 @@ public void testDefaults() public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() - .put("cassandra.max-schema-refresh-threads", "2") - .put("cassandra.schema-cache-ttl", "2h") - .put("cassandra.schema-refresh-interval", "30m") .put("cassandra.contact-points", "host1,host2") .put("cassandra.native-protocol-port", "9999") .put("cassandra.fetch-size", "10000") @@ -96,9 +89,6 @@ public void testExplicitPropertyMappings() .build(); CassandraClientConfig expected = new CassandraClientConfig() - .setMaxSchemaRefreshThreads(2) - .setSchemaCacheTtl(new Duration(2, TimeUnit.HOURS)) - .setSchemaRefreshInterval(new Duration(30, TimeUnit.MINUTES)) .setContactPoints("host1", "host2") .setNativeProtocolPort(9999) .setFetchSize(10_000) diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java index 1efa56d1d59a6..7506ceffc03e1 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java @@ -19,6 +19,7 @@ import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; import com.google.common.collect.ImmutableList; +import io.airlift.units.Duration; import org.joda.time.DateTime; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -32,8 +33,10 @@ import static com.datastax.driver.core.utils.Bytes.toRawHexString; import static com.facebook.presto.cassandra.CassandraQueryRunner.createCassandraSession; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_ALL_TYPES; +import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_ALL_TYPES_INSERT; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_ALL_TYPES_PARTITION_KEY; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_CLUSTERING_KEYS; +import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_CLUSTERING_KEYS_INEQUALITY; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_CLUSTERING_KEYS_LARGE; import static com.facebook.presto.cassandra.CassandraTestingUtils.TABLE_MULTI_PARTITION_CLUSTERING_KEYS; import static com.facebook.presto.cassandra.CassandraTestingUtils.createTestTables; @@ -47,7 +50,11 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.testing.MaterializedResult.DEFAULT_PRECISION; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.tests.QueryAssertions.assertContains; +import static com.facebook.presto.tests.QueryAssertions.assertContainsEventually; import static com.google.common.primitives.Ints.toByteArray; +import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.stream.Collectors.toList; import static org.joda.time.DateTimeZone.UTC; import static org.testng.Assert.assertEquals; @@ -63,6 +70,8 @@ public class TestCassandraIntegrationSmokeTest private static final Date DATE_LOCAL = new Date(DATE_TIME_UTC.getMillis()); private static final Timestamp TIMESTAMP_LOCAL = new Timestamp(DATE_TIME_UTC.getMillis()); + private CassandraSession session; + public TestCassandraIntegrationSmokeTest() throws Exception { @@ -73,7 +82,8 @@ public TestCassandraIntegrationSmokeTest() public void setUp() throws Exception { - createTestTables(EmbeddedCassandra.getSession(), KEYSPACE, DATE_LOCAL); + session = EmbeddedCassandra.getSession(); + createTestTables(session, KEYSPACE, DATE_LOCAL); } @Test @@ -193,6 +203,304 @@ public void testClusteringKeyOnlyPushdown() assertEquals(execute(sql).getRowCount(), 1); sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two='clust_two_2' AND clust_three='clust_three_2'"; assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two='clust_two_2' AND clust_three IN ('clust_three_1', 'clust_three_2', 'clust_three_3')"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three IN ('clust_three_1', 'clust_three_2', 'clust_three_3')"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two > 'clust_two_998'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two > 'clust_two_997' AND clust_two < 'clust_two_999'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three > 'clust_three_998'"; + assertEquals(execute(sql).getRowCount(), 0); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three < 'clust_three_3'"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2') AND clust_three > 'clust_three_1' AND clust_three < 'clust_three_3'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2','clust_two_3') AND clust_two < 'clust_two_2'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_997','clust_two_998','clust_two_999') AND clust_two > 'clust_two_998'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_LARGE + " WHERE clust_one='clust_one' AND clust_two IN ('clust_two_1','clust_two_2','clust_two_3') AND clust_two = 'clust_two_2'"; + assertEquals(execute(sql).getRowCount(), 1); + } + + @Test + public void testClusteringKeyPushdownInequality() + throws Exception + { + String sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one'"; + assertEquals(execute(sql).getRowCount(), 4); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2 AND clust_three = timestamp '1970-01-01 03:04:05.020'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2 AND clust_three = timestamp '1970-01-01 03:04:05.010'"; + assertEquals(execute(sql).getRowCount(), 0); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2)"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two > 1 AND clust_two < 3"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two=2 AND clust_three >= timestamp '1970-01-01 03:04:05.010' AND clust_three <= timestamp '1970-01-01 03:04:05.020'"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2) AND clust_three >= timestamp '1970-01-01 03:04:05.010' AND clust_three <= timestamp '1970-01-01 03:04:05.020'"; + assertEquals(execute(sql).getRowCount(), 2); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2,3) AND clust_two < 2"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2,3) AND clust_two > 2"; + assertEquals(execute(sql).getRowCount(), 1); + sql = "SELECT * FROM " + TABLE_CLUSTERING_KEYS_INEQUALITY + " WHERE key='key_1' AND clust_one='clust_one' AND clust_two IN (1,2,3) AND clust_two = 2"; + assertEquals(execute(sql).getRowCount(), 1); + } + + @Test + public void testUpperCaseNameUnescapedInCassandra() + throws Exception + { + /* + * If an identifier is not escaped with double quotes it is stored as lowercase in the Cassandra metadata + * + * http://docs.datastax.com/en/cql/3.1/cql/cql_reference/ucase-lcase_r.html + */ + session.execute("CREATE KEYSPACE KEYSPACE_1 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_1") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE KEYSPACE_1.TABLE_1 (COLUMN_1 bigint PRIMARY KEY)"); + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_1"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_1") + .build(), new Duration(1, MINUTES)); + assertContains(execute("SHOW COLUMNS FROM cassandra.keyspace_1.table_1"), resultBuilder(getSession(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType()) + .row("column_1", "bigint", "", "") + .build()); + + execute("INSERT INTO keyspace_1.table_1 (column_1) VALUES (1)"); + + assertEquals(execute("SELECT column_1 FROM cassandra.keyspace_1.table_1").getRowCount(), 1); + assertUpdate("DROP TABLE cassandra.keyspace_1.table_1"); + + // when an identifier is unquoted the lowercase and uppercase spelling may be used interchangeable + session.execute("DROP KEYSPACE keyspace_1"); + } + + @Test + public void testUppercaseNameEscaped() + throws Exception + { + /* + * If an identifier is escaped with double quotes it is stored verbatim + * + * http://docs.datastax.com/en/cql/3.1/cql/cql_reference/ucase-lcase_r.html + */ + session.execute("CREATE KEYSPACE \"KEYSPACE_2\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_2") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\" bigint PRIMARY KEY)"); + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_2"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_2") + .build(), new Duration(1, MINUTES)); + assertContains(execute("SHOW COLUMNS FROM cassandra.keyspace_2.table_2"), resultBuilder(getSession(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType()) + .row("column_2", "bigint", "", "") + .build()); + + execute("INSERT INTO \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\") VALUES (1)"); + + assertEquals(execute("SELECT column_2 FROM cassandra.keyspace_2.table_2").getRowCount(), 1); + assertUpdate("DROP TABLE cassandra.keyspace_2.table_2"); + + // when an identifier is unquoted the lowercase and uppercase spelling may be used interchangeable + session.execute("DROP KEYSPACE \"KEYSPACE_2\""); + } + + @Test + public void testKeyspaceNameAmbiguity() + throws Exception + { + // Identifiers enclosed in double quotes are stored in Cassandra verbatim. It is possible to create 2 keyspaces with names + // that have differences only in letters case. + session.execute("CREATE KEYSPACE \"KeYsPaCe_3\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + session.execute("CREATE KEYSPACE \"kEySpAcE_3\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + + // Although in Presto all the schema and table names are always displayed as lowercase + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_3") + .row("keyspace_3") + .build(), new Duration(1, MINUTES)); + + // There is no way to figure out what the exactly keyspace we want to retrieve tables from + assertQueryFailsEventually( + "SHOW TABLES FROM cassandra.keyspace_3", + "More than one keyspace has been found for the case insensitive schema name: keyspace_3 -> \\(KeYsPaCe_3, kEySpAcE_3\\)", + new Duration(1, MINUTES)); + + session.execute("DROP KEYSPACE \"KeYsPaCe_3\""); + session.execute("DROP KEYSPACE \"kEySpAcE_3\""); + } + + @Test + public void testTableNameAmbiguity() + throws Exception + { + session.execute("CREATE KEYSPACE keyspace_4 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_4") + .build(), new Duration(1, MINUTES)); + + // Identifiers enclosed in double quotes are stored in Cassandra verbatim. It is possible to create 2 tables with names + // that have differences only in letters case. + session.execute("CREATE TABLE keyspace_4.\"TaBlE_4\" (column_4 bigint PRIMARY KEY)"); + session.execute("CREATE TABLE keyspace_4.\"tAbLe_4\" (column_4 bigint PRIMARY KEY)"); + + // Although in Presto all the schema and table names are always displayed as lowercase + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_4"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_4") + .row("table_4") + .build(), new Duration(1, MINUTES)); + + // There is no way to figure out what the exactly table is being queried + assertQueryFailsEventually( + "SHOW COLUMNS FROM cassandra.keyspace_4.table_4", + "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", + new Duration(1, MINUTES)); + assertQueryFailsEventually( + "SELECT * FROM cassandra.keyspace_4.table_4", + "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", + new Duration(1, MINUTES)); + session.execute("DROP KEYSPACE keyspace_4"); + } + + @Test + public void testColumnNameAmbiguity() + throws Exception + { + session.execute("CREATE KEYSPACE keyspace_5 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_5") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE keyspace_5.table_5 (\"CoLuMn_5\" bigint PRIMARY KEY, \"cOlUmN_5\" bigint)"); + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_5"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_5") + .build(), new Duration(1, MINUTES)); + + assertQueryFailsEventually( + "SHOW COLUMNS FROM cassandra.keyspace_5.table_5", + "More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", + new Duration(1, MINUTES)); + assertQueryFailsEventually( + "SELECT * FROM cassandra.keyspace_5.table_5", + "More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", + new Duration(1, MINUTES)); + + session.execute("DROP KEYSPACE keyspace_5"); + } + + @Test + public void testInsert() + { + String sql = "SELECT key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal, " + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + " FROM " + TABLE_ALL_TYPES_INSERT; + assertEquals(execute(sql).getRowCount(), 0); + + // TODO Following types are not supported now. We need to change null into the value after fixing it + // blob, frozen>, inet, list, map, set, timeuuid, decimal, uuid, varint + // timestamp can be inserted but the expected and actual values are not same + execute("INSERT INTO " + TABLE_ALL_TYPES_INSERT + " (" + + "key," + + "typeuuid," + + "typeinteger," + + "typelong," + + "typebytes," + + "typetimestamp," + + "typeansi," + + "typeboolean," + + "typedecimal," + + "typedouble," + + "typefloat," + + "typeinet," + + "typevarchar," + + "typevarint," + + "typetimeuuid," + + "typelist," + + "typemap," + + "typeset" + + ") VALUES (" + + "'key1', " + + "null, " + + "1, " + + "1000, " + + "null, " + + "timestamp '1970-01-01 08:34:05.0', " + + "'ansi1', " + + "true, " + + "null, " + + "0.3, " + + "cast('0.4' as real), " + + "null, " + + "'varchar1', " + + "null, " + + "null, " + + "null, " + + "null, " + + "null " + + ")"); + + MaterializedResult result = execute(sql); + int rowCount = result.getRowCount(); + assertEquals(rowCount, 1); + assertEquals(result.getMaterializedRows().get(0), new MaterializedRow(DEFAULT_PRECISION, + "key1", + null, + 1, + 1000L, + null, + Timestamp.valueOf("1970-01-01 14:04:05.0"), + "ansi1", + true, + null, + 0.3, + (float) 0.4, + null, + "varchar1", + null, + null, + null, + null, + null + )); + + // insert null for all datatypes + execute("INSERT INTO " + TABLE_ALL_TYPES_INSERT + " (" + + "key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal," + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + ") VALUES (" + + "'key2', null, null, null, null, null, null, null, null," + + "null, null, null, null, null, null, null, null, null)"); + sql = "SELECT key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal, " + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + " FROM " + TABLE_ALL_TYPES_INSERT + " WHERE key = 'key2'"; + result = execute(sql); + rowCount = result.getRowCount(); + assertEquals(rowCount, 1); + assertEquals(result.getMaterializedRows().get(0), new MaterializedRow(DEFAULT_PRECISION, + "key2", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)); + + // insert into only a subset of columns + execute("INSERT INTO " + TABLE_ALL_TYPES_INSERT + " (" + + "key, typeinteger, typeansi, typeboolean) VALUES (" + + "'key3', 999, 'ansi', false)"); + sql = "SELECT key, typeuuid, typeinteger, typelong, typebytes, typetimestamp, typeansi, typeboolean, typedecimal, " + + "typedouble, typefloat, typeinet, typevarchar, typevarint, typetimeuuid, typelist, typemap, typeset" + + " FROM " + TABLE_ALL_TYPES_INSERT + " WHERE key = 'key3'"; + result = execute(sql); + rowCount = result.getRowCount(); + assertEquals(rowCount, 1); + assertEquals(result.getMaterializedRows().get(0), new MaterializedRow(DEFAULT_PRECISION, + "key3", null, 999, null, null, null, "ansi", false, null, null, null, null, null, null, null, null, null, null)); } private void assertSelect(String tableName, boolean createdByPresto) diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java index 0a84fd5630864..367ede2cf3900 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java @@ -14,7 +14,7 @@ package com.facebook.presto.cassandra; import com.facebook.presto.cassandra.CassandraTokenSplitManager.TokenSplit; -import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.List; @@ -27,42 +27,46 @@ public class TestCassandraTokenSplitManager { private static final int SPLIT_SIZE = 100; private static final String KEYSPACE = "test_cassandra_token_split_manager_keyspace"; - private static final String TABLE = "test_cassandra_token_split_manager_table"; private static final int PARTITION_COUNT = 1000; private CassandraSession session; private CassandraTokenSplitManager splitManager; - @BeforeMethod + @BeforeClass public void setUp() throws Exception { EmbeddedCassandra.start(); session = EmbeddedCassandra.getSession(); + createKeyspace(session, KEYSPACE); splitManager = new CassandraTokenSplitManager(session, SPLIT_SIZE); } @Test - public void testCassandraTokenSplitManager() + public void testEmptyTable() throws Exception { - createKeyspace(session, KEYSPACE); - session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, TABLE)); - - EmbeddedCassandra.flush(KEYSPACE, TABLE); - EmbeddedCassandra.refreshSizeEstimates(); - - List splits = splitManager.getSplits(KEYSPACE, TABLE); + String tableName = "empty_table"; + session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, tableName)); + EmbeddedCassandra.refreshSizeEstimates(KEYSPACE, tableName); + List splits = splitManager.getSplits(KEYSPACE, tableName); // even for the empty table at least one split must be produced, in case the statistics are inaccurate assertEquals(splits.size(), 1); + session.execute(format("DROP TABLE %s.%s", KEYSPACE, tableName)); + } + @Test + public void testNonEmptyTable() + throws Exception + { + String tableName = "non_empty_table"; + session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, tableName)); for (int i = 0; i < PARTITION_COUNT; i++) { - session.execute(format("INSERT INTO %s.%s (key) VALUES ('%s')", KEYSPACE, TABLE, "value" + i)); + session.execute(format("INSERT INTO %s.%s (key) VALUES ('%s')", KEYSPACE, tableName, "value" + i)); } - EmbeddedCassandra.flush(KEYSPACE, TABLE); - EmbeddedCassandra.refreshSizeEstimates(); - - splits = splitManager.getSplits(KEYSPACE, TABLE); + EmbeddedCassandra.refreshSizeEstimates(KEYSPACE, tableName); + List splits = splitManager.getSplits(KEYSPACE, tableName); assertEquals(splits.size(), PARTITION_COUNT / SPLIT_SIZE); + session.execute(format("DROP TABLE %s.%s", KEYSPACE, tableName)); } } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java index 5a56a88d3d473..78e5c82610b7d 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraClusteringPredicatesExtractor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cassandra.util; +import com.datastax.driver.core.VersionNumber; import com.facebook.presto.cassandra.CassandraClusteringPredicatesExtractor; import com.facebook.presto.cassandra.CassandraColumnHandle; import com.facebook.presto.cassandra.CassandraTable; @@ -26,8 +27,6 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; -import java.util.List; - import static com.facebook.presto.spi.type.BigintType.BIGINT; import static org.testng.Assert.assertEquals; @@ -38,6 +37,7 @@ public class TestCassandraClusteringPredicatesExtractor private static CassandraColumnHandle col3; private static CassandraColumnHandle col4; private static CassandraTable cassandraTable; + private static VersionNumber cassandraVersion; @BeforeTest void setUp() @@ -50,6 +50,8 @@ void setUp() cassandraTable = new CassandraTable( new CassandraTableHandle("cassandra", "test", "records"), ImmutableList.of(col1, col2, col3, col4)); + + cassandraVersion = VersionNumber.parse("2.1.5"); } @Test @@ -60,9 +62,9 @@ public void testBuildClusteringPredicate() col1, Domain.singleValue(BIGINT, 23L), col2, Domain.singleValue(BIGINT, 34L), col4, Domain.singleValue(BIGINT, 26L))); - CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain); - List predicate = predicatesExtractor.getClusteringKeyPredicates(); - assertEquals(predicate.get(0), new StringBuilder("\"clusteringKey1\" = 34").toString()); + CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); + String predicate = predicatesExtractor.getClusteringKeyPredicates(); + assertEquals(predicate, new StringBuilder("\"clusteringKey1\" = 34").toString()); } @Test @@ -72,7 +74,7 @@ public void testGetUnenforcedPredicates() ImmutableMap.of( col2, Domain.singleValue(BIGINT, 34L), col4, Domain.singleValue(BIGINT, 26L))); - CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain); + CassandraClusteringPredicatesExtractor predicatesExtractor = new CassandraClusteringPredicatesExtractor(cassandraTable.getClusteringKeyColumns(), tupleDomain, cassandraVersion); TupleDomain unenforcedPredicates = TupleDomain.withColumnDomains(ImmutableMap.of(col4, Domain.singleValue(BIGINT, 26L))); assertEquals(predicatesExtractor.getUnenforcedConstraints(), unenforcedPredicates); } diff --git a/presto-cli/pom.xml b/presto-cli/pom.xml index 0bee905c27e36..b890075daed90 100644 --- a/presto-cli/pom.xml +++ b/presto-cli/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-cli @@ -14,8 +14,6 @@ ${project.parent.basedir} com.facebook.presto.cli.Presto - false - ${main-class} @@ -39,16 +37,6 @@ concurrent - - io.airlift - http-client - - - - io.airlift - json - - io.airlift log @@ -89,6 +77,11 @@ opencsv + + com.squareup.okhttp3 + okhttp + + org.testng diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java b/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java index 819fc938b0b09..12f5b67217df1 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java @@ -18,10 +18,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.net.HostAndPort; import io.airlift.airline.Option; -import io.airlift.http.client.spnego.KerberosConfig; import io.airlift.units.Duration; -import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.CharsetEncoder; @@ -33,6 +31,7 @@ import java.util.Optional; import java.util.TimeZone; +import static com.facebook.presto.client.KerberosUtil.defaultCredentialCachePath; import static com.google.common.base.Preconditions.checkArgument; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.util.Collections.emptyMap; @@ -58,7 +57,7 @@ public class ClientOptions public String krb5KeytabPath = "/etc/krb5.keytab"; @Option(name = "--krb5-credential-cache-path", title = "krb5 credential cache path", description = "Kerberos credential cache path") - public String krb5CredentialCachePath = defaultCredentialCachePath(); + public String krb5CredentialCachePath = defaultCredentialCachePath().orElse(null); @Option(name = "--krb5-principal", title = "krb5 principal", description = "Kerberos principal to be used") public String krb5Principal; @@ -114,6 +113,9 @@ public class ClientOptions @Option(name = "--socks-proxy", title = "socks-proxy", description = "SOCKS proxy to use for server connections") public HostAndPort socksProxy; + @Option(name = "--http-proxy", title = "http-proxy", description = "HTTP proxy to use for server connections") + public HostAndPort httpProxy; + @Option(name = "--client-request-timeout", title = "client request timeout", description = "Client request timeout (default: 2m)") public Duration clientRequestTimeout = new Duration(2, MINUTES); @@ -146,22 +148,6 @@ public ClientSession toClientSession() clientRequestTimeout); } - public KerberosConfig toKerberosConfig() - { - KerberosConfig config = new KerberosConfig(); - if (krb5ConfigPath != null) { - config.setConfig(new File(krb5ConfigPath)); - } - if (krb5KeytabPath != null) { - config.setKeytab(new File(krb5KeytabPath)); - } - if (krb5CredentialCachePath != null) { - config.setCredentialCache(new File(krb5CredentialCachePath)); - } - config.setUseCanonicalHostname(!krb5DisableRemoteServiceHostnameCanonicalization); - return config; - } - public static URI parseServer(String server) { server = server.toLowerCase(ENGLISH); @@ -171,7 +157,7 @@ public static URI parseServer(String server) HostAndPort host = HostAndPort.fromString(server); try { - return new URI("http", null, host.getHostText(), host.getPortOrDefault(80), null, null, null); + return new URI("http", null, host.getHost(), host.getPortOrDefault(80), null, null, null); } catch (URISyntaxException e) { throw new IllegalArgumentException(e); @@ -191,15 +177,6 @@ public static Map toProperties(List sessi return builder.build(); } - private static String defaultCredentialCachePath() - { - String value = System.getenv("KRB5CCNAME"); - if (value != null && value.startsWith("FILE:")) { - return value.substring("FILE:".length()); - } - return value; - } - public static final class ClientSessionProperty { private static final Splitter NAME_VALUE_SPLITTER = Splitter.on('=').limit(2); diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/Console.java b/presto-cli/src/main/java/com/facebook/presto/cli/Console.java index 982ea5ed4c7c0..ef72bdad3d7ac 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/Console.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/Console.java @@ -27,7 +27,6 @@ import com.google.common.io.Files; import io.airlift.airline.Command; import io.airlift.airline.HelpOption; -import io.airlift.http.client.spnego.KerberosConfig; import io.airlift.log.Logging; import io.airlift.log.LoggingConfiguration; import io.airlift.units.Duration; @@ -51,6 +50,7 @@ import static com.facebook.presto.cli.Completion.commandCompleter; import static com.facebook.presto.cli.Completion.lowerCaseCommandCompleter; import static com.facebook.presto.cli.Help.getHelpText; +import static com.facebook.presto.cli.QueryPreprocessor.preprocessQuery; import static com.facebook.presto.client.ClientSession.stripTransactionId; import static com.facebook.presto.client.ClientSession.withCatalogAndSchema; import static com.facebook.presto.client.ClientSession.withPreparedStatements; @@ -93,7 +93,6 @@ public class Console public void run() { ClientSession session = clientOptions.toClientSession(); - KerberosConfig kerberosConfig = clientOptions.toKerberosConfig(); boolean hasQuery = !Strings.isNullOrEmpty(clientOptions.execute); boolean isFromFile = !Strings.isNullOrEmpty(clientOptions.file); @@ -124,9 +123,10 @@ public void run() AtomicBoolean exiting = new AtomicBoolean(); interruptThreadOnExit(Thread.currentThread(), exiting); - try (QueryRunner queryRunner = QueryRunner.create( + try (QueryRunner queryRunner = new QueryRunner( session, Optional.ofNullable(clientOptions.socksProxy), + Optional.ofNullable(clientOptions.httpProxy), Optional.ofNullable(clientOptions.keystorePath), Optional.ofNullable(clientOptions.keystorePassword), Optional.ofNullable(clientOptions.truststorePath), @@ -135,8 +135,11 @@ public void run() clientOptions.password ? Optional.of(getPassword()) : Optional.empty(), Optional.ofNullable(clientOptions.krb5Principal), Optional.ofNullable(clientOptions.krb5RemoteServiceName), - clientOptions.authenticationEnabled, - kerberosConfig)) { + Optional.ofNullable(clientOptions.krb5ConfigPath), + Optional.ofNullable(clientOptions.krb5KeytabPath), + Optional.ofNullable(clientOptions.krb5CredentialCachePath), + !clientOptions.krb5DisableRemoteServiceHostnameCanonicalization, + clientOptions.authenticationEnabled)) { if (hasQuery) { executeCommand(queryRunner, query, clientOptions.outputFormat); } @@ -314,7 +317,22 @@ private static void executeCommand(QueryRunner queryRunner, String query, Output private static void process(QueryRunner queryRunner, String sql, OutputFormat outputFormat, boolean interactive) { - try (Query query = queryRunner.startQuery(sql)) { + String finalSql; + try { + finalSql = preprocessQuery( + Optional.ofNullable(queryRunner.getSession().getCatalog()), + Optional.ofNullable(queryRunner.getSession().getSchema()), + sql); + } + catch (QueryPreprocessorException e) { + System.err.println(e.getMessage()); + if (queryRunner.getSession().isDebug()) { + e.printStackTrace(); + } + return; + } + + try (Query query = queryRunner.startQuery(finalSql)) { query.renderOutput(System.out, outputFormat, interactive); ClientSession session = queryRunner.getSession(); diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/LdapRequestFilter.java b/presto-cli/src/main/java/com/facebook/presto/cli/LdapRequestFilter.java deleted file mode 100644 index de40f0e32cc88..0000000000000 --- a/presto-cli/src/main/java/com/facebook/presto/cli/LdapRequestFilter.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cli; - -import com.google.common.net.HttpHeaders; -import io.airlift.http.client.HttpRequestFilter; -import io.airlift.http.client.Request; - -import java.util.Base64; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.http.client.Request.Builder.fromRequest; -import static java.nio.charset.StandardCharsets.ISO_8859_1; -import static java.util.Objects.requireNonNull; - -public class LdapRequestFilter - implements HttpRequestFilter -{ - private final String user; - private final String password; - - public LdapRequestFilter(String user, String password) - { - this.user = requireNonNull(user, "user is null"); - checkArgument(!user.contains(":"), "Illegal character ':' found in username"); - this.password = requireNonNull(password, "password is null"); - } - - @Override - public Request filterRequest(Request request) - { - String value = "Basic " + Base64.getEncoder().encodeToString((user + ":" + password).getBytes(ISO_8859_1)); - return fromRequest(request) - .addHeader(HttpHeaders.AUTHORIZATION, value) - .build(); - } -} diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java b/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java new file mode 100644 index 0000000000000..2ce2a99a16180 --- /dev/null +++ b/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java @@ -0,0 +1,213 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cli; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.io.CharStreams; +import io.airlift.units.Duration; +import sun.misc.Signal; +import sun.misc.SignalHandler; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.presto.cli.ConsolePrinter.REAL_TERMINAL; +import static com.google.common.base.Strings.emptyToNull; +import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.base.Throwables.propagateIfPossible; +import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +public final class QueryPreprocessor +{ + public static final String ENV_PREPROCESSOR = "PRESTO_PREPROCESSOR"; + public static final String ENV_PREPROCESSOR_TIMEOUT = "PRESTO_PREPROCESSOR_TIMEOUT"; + public static final String ENV_PRESTO_CATALOG = "PRESTO_CATALOG"; + public static final String ENV_PRESTO_SCHEMA = "PRESTO_SCHEMA"; + private static final Duration DEFAULT_PREPROCESSOR_TIMEOUT = new Duration(10, SECONDS); + + private static final Signal SIGINT = new Signal("INT"); + private static final String PREPROCESSING_QUERY_MESSAGE = "Preprocessing query..."; + + private QueryPreprocessor() {} + + public static String preprocessQuery(Optional catalog, Optional schema, String query) + throws QueryPreprocessorException + { + Duration timeout = DEFAULT_PREPROCESSOR_TIMEOUT; + String timeoutEnvironment = nullToEmpty(System.getenv(ENV_PREPROCESSOR_TIMEOUT)).trim(); + if (!timeoutEnvironment.isEmpty()) { + timeout = Duration.valueOf(timeoutEnvironment); + } + + String preprocessorCommand = System.getenv(ENV_PREPROCESSOR); + if (emptyToNull(preprocessorCommand) == null) { + return query; + } + return preprocessQuery(catalog, schema, query, ImmutableList.of("/bin/sh", "-c", preprocessorCommand), timeout); + } + + public static String preprocessQuery(Optional catalog, Optional schema, String query, List preprocessorCommand, Duration timeout) + throws QueryPreprocessorException + { + Thread clientThread = Thread.currentThread(); + SignalHandler oldHandler = Signal.handle(SIGINT, signal -> clientThread.interrupt()); + try { + if (REAL_TERMINAL) { + System.out.print(PREPROCESSING_QUERY_MESSAGE); + System.out.flush(); + } + return preprocessQueryInternal(catalog, schema, query, preprocessorCommand, timeout); + } + finally { + if (REAL_TERMINAL) { + System.out.print("\r" + Strings.repeat(" ", PREPROCESSING_QUERY_MESSAGE.length()) + "\r"); + System.out.flush(); + } + Signal.handle(SIGINT, oldHandler); + Thread.interrupted(); // clear interrupt status + } + } + + private static String preprocessQueryInternal(Optional catalog, Optional schema, String query, List preprocessorCommand, Duration timeout) + throws QueryPreprocessorException + { + // execute the process in a child thread so we can better handle interruption and timeouts + AtomicReference processReference = new AtomicReference<>(); + + Future task = executeInNewThread("Query preprocessor", () -> { + String result; + int exitCode; + Future readStderr; + try { + ProcessBuilder processBuilder = new ProcessBuilder(preprocessorCommand); + processBuilder.environment().put(ENV_PRESTO_CATALOG, catalog.orElse("")); + processBuilder.environment().put(ENV_PRESTO_SCHEMA, schema.orElse("")); + + Process process = processBuilder.start(); + processReference.set(process); + + Future writeOutput = null; + try { + // write query to process standard out + writeOutput = executeInNewThread("Query preprocessor output", () -> { + try (OutputStream outputStream = process.getOutputStream()) { + outputStream.write(query.getBytes(UTF_8)); + } + return null; + }); + + // read stderr + readStderr = executeInNewThread("Query preprocessor read stderr", () -> { + StringBuilder builder = new StringBuilder(); + try (InputStream inputStream = process.getErrorStream()) { + CharStreams.copy(new InputStreamReader(inputStream, UTF_8), builder); + } + catch (IOException | RuntimeException ignored) { + } + return builder.toString(); + }); + + // read response + try (InputStream inputStream = process.getInputStream()) { + result = CharStreams.toString(new InputStreamReader(inputStream, UTF_8)); + } + + // verify output was written successfully + try { + writeOutput.get(); + } + catch (ExecutionException e) { + throw e.getCause(); + } + + // wait for process to finish + exitCode = process.waitFor(); + } + finally { + process.destroyForcibly(); + if (writeOutput != null) { + writeOutput.cancel(true); + } + } + } + catch (QueryPreprocessorException e) { + throw e; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new QueryPreprocessorException("Interrupted while preprocessing query"); + } + catch (Throwable e) { + throw new QueryPreprocessorException("Error preprocessing query: " + e.getMessage(), e); + } + + // check we got a valid exit code + if (exitCode != 0) { + Optional errorMessage = tryGetFutureValue(readStderr, 100, MILLISECONDS) + .flatMap(value -> Optional.ofNullable(emptyToNull(value.trim()))); + + throw new QueryPreprocessorException("Query preprocessor exited " + exitCode + + errorMessage.map(message1 -> "\n===\n" + message1 + "\n===").orElse("")); + } + return result; + }); + + try { + return task.get(timeout.toMillis(), MILLISECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new QueryPreprocessorException("Interrupted while preprocessing query"); + } + catch (ExecutionException e) { + Throwable cause = e.getCause(); + propagateIfPossible(cause, QueryPreprocessorException.class); + throw new QueryPreprocessorException("Error preprocessing query: " + cause.getMessage(), cause); + } + catch (TimeoutException e) { + throw new QueryPreprocessorException("Timed out waiting for query preprocessor after " + timeout); + } + finally { + Process process = processReference.get(); + if (process != null) { + process.destroyForcibly(); + } + task.cancel(true); + } + } + + private static Future executeInNewThread(String threadName, Callable callable) + { + FutureTask task = new FutureTask<>(callable); + Thread thread = new Thread(task); + thread.setName(threadName); + thread.setDaemon(true); + thread.start(); + return task; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingTableHandle.java b/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessorException.java similarity index 59% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/TestingTableHandle.java rename to presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessorException.java index 1db8042dfa504..62f4505f916f7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingTableHandle.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessorException.java @@ -11,19 +11,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner; +package com.facebook.presto.cli; -import com.facebook.presto.spi.ConnectorTableHandle; -import com.fasterxml.jackson.annotation.JsonValue; - -public class TestingTableHandle - implements ConnectorTableHandle +public class QueryPreprocessorException + extends Exception { - // Jackson refuses to serialize this class otherwise because it's empty. - @JsonValue - @Override - public String toString() + public QueryPreprocessorException(String message) + { + super(message); + } + + public QueryPreprocessorException(String message, Throwable cause) { - return getClass().getSimpleName(); + super(message, cause); } } diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java b/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java index 80fa038eafabc..83031ba99574b 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/QueryRunner.java @@ -13,39 +13,37 @@ */ package com.facebook.presto.cli; +import com.facebook.presto.client.ClientException; import com.facebook.presto.client.ClientSession; -import com.facebook.presto.client.QueryResults; import com.facebook.presto.client.StatementClient; -import com.google.common.collect.ImmutableList; import com.google.common.net.HostAndPort; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClientConfig; -import io.airlift.http.client.HttpRequestFilter; -import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.http.client.spnego.KerberosConfig; -import io.airlift.json.JsonCodec; -import io.airlift.units.Duration; +import okhttp3.OkHttpClient; import java.io.Closeable; +import java.io.File; import java.util.Optional; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static com.facebook.presto.client.OkHttpUtil.basicAuth; +import static com.facebook.presto.client.OkHttpUtil.setupHttpProxy; +import static com.facebook.presto.client.OkHttpUtil.setupKerberos; +import static com.facebook.presto.client.OkHttpUtil.setupSocksProxy; +import static com.facebook.presto.client.OkHttpUtil.setupSsl; +import static com.facebook.presto.client.OkHttpUtil.setupTimeouts; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.json.JsonCodec.jsonCodec; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; public class QueryRunner implements Closeable { - private final JsonCodec queryResultsCodec; private final AtomicReference session; - private final HttpClient httpClient; + private final OkHttpClient httpClient; public QueryRunner( ClientSession session, - JsonCodec queryResultsCodec, Optional socksProxy, + Optional httpProxy, Optional keystorePath, Optional keystorePassword, Optional truststorePath, @@ -54,24 +52,34 @@ public QueryRunner( Optional password, Optional kerberosPrincipal, Optional kerberosRemoteServiceName, - boolean authenticationEnabled, - KerberosConfig kerberosConfig) + Optional kerberosConfigPath, + Optional kerberosKeytabPath, + Optional kerberosCredentialCachePath, + boolean kerberosUseCanonicalHostname, + boolean kerberosEnabled) { this.session = new AtomicReference<>(requireNonNull(session, "session is null")); - this.queryResultsCodec = requireNonNull(queryResultsCodec, "queryResultsCodec is null"); - this.httpClient = new JettyHttpClient( - getHttpClientConfig( - socksProxy, - keystorePath, - keystorePassword, - truststorePath, - truststorePassword, - kerberosPrincipal, - kerberosRemoteServiceName, - authenticationEnabled), - kerberosConfig, - Optional.empty(), - getRequestFilters(session, user, password)); + + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + + setupTimeouts(builder, 5, SECONDS); + setupSocksProxy(builder, socksProxy); + setupHttpProxy(builder, httpProxy); + setupSsl(builder, keystorePath, keystorePassword, truststorePath, truststorePassword); + setupBasicAuth(builder, session, user, password); + + if (kerberosEnabled) { + setupKerberos( + builder, + kerberosRemoteServiceName.orElseThrow(() -> new ClientException("Kerberos remote service name must be set")), + kerberosUseCanonicalHostname, + kerberosPrincipal, + kerberosConfigPath.map(File::new), + kerberosKeytabPath.map(File::new), + kerberosCredentialCachePath.map(File::new)); + } + + this.httpClient = builder.build(); } public ClientSession getSession() @@ -91,79 +99,26 @@ public Query startQuery(String query) public StatementClient startInternalQuery(String query) { - return new StatementClient(httpClient, queryResultsCodec, session.get(), query); + return new StatementClient(httpClient, session.get(), query); } @Override public void close() { - httpClient.close(); + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); } - public static QueryRunner create( + private static void setupBasicAuth( + OkHttpClient.Builder clientBuilder, ClientSession session, - Optional socksProxy, - Optional keystorePath, - Optional keystorePassword, - Optional truststorePath, - Optional truststorePassword, Optional user, - Optional password, - Optional kerberosPrincipal, - Optional kerberosRemoteServiceName, - boolean authenticationEnabled, - KerberosConfig kerberosConfig) - { - return new QueryRunner( - session, - jsonCodec(QueryResults.class), - socksProxy, - keystorePath, - keystorePassword, - truststorePath, - truststorePassword, - user, - password, - kerberosPrincipal, - kerberosRemoteServiceName, - authenticationEnabled, - kerberosConfig); - } - - private static HttpClientConfig getHttpClientConfig( - Optional socksProxy, - Optional keystorePath, - Optional keystorePassword, - Optional truststorePath, - Optional truststorePassword, - Optional kerberosPrincipal, - Optional kerberosRemoteServiceName, - boolean authenticationEnabled) - { - HttpClientConfig httpClientConfig = new HttpClientConfig() - .setConnectTimeout(new Duration(5, TimeUnit.SECONDS)) - .setRequestTimeout(new Duration(5, TimeUnit.SECONDS)); - - socksProxy.ifPresent(httpClientConfig::setSocksProxy); - - httpClientConfig.setAuthenticationEnabled(authenticationEnabled); - - keystorePath.ifPresent(httpClientConfig::setKeyStorePath); - keystorePassword.ifPresent(httpClientConfig::setKeyStorePassword); - truststorePath.ifPresent(httpClientConfig::setTrustStorePath); - truststorePassword.ifPresent(httpClientConfig::setTrustStorePassword); - kerberosPrincipal.ifPresent(httpClientConfig::setKerberosPrincipal); - kerberosRemoteServiceName.ifPresent(httpClientConfig::setKerberosRemoteServiceName); - - return httpClientConfig; - } - - private static Iterable getRequestFilters(ClientSession session, Optional user, Optional password) + Optional password) { if (user.isPresent() && password.isPresent()) { - checkArgument(session.getServer().getScheme().equalsIgnoreCase("https"), "Authentication using username/password requires HTTPS to be enabled"); - return ImmutableList.of(new LdapRequestFilter(user.get(), password.get())); + checkArgument(session.getServer().getScheme().equalsIgnoreCase("https"), + "Authentication using username/password requires HTTPS to be enabled"); + clientBuilder.addInterceptor(basicAuth(user.get(), password.get())); } - return ImmutableList.of(); } } diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java b/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java index 7d45a47cb7ae3..80f65a834ec2d 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java @@ -99,7 +99,7 @@ public void printInitialStatusUpdates() // check for keyboard input int key = readKey(); if (key == CTRL_P) { - partialCancel(); + client.cancelLeafStage(); } else if (key == CTRL_C) { updateScreen(); @@ -406,16 +406,6 @@ private void printStageTree(StageStats stage, String indent, AtomicInteger stage } } - private void partialCancel() - { - try { - client.cancelLeafStage(new Duration(1, SECONDS)); - } - catch (RuntimeException e) { - log.debug(e, "error canceling leaf stage"); - } - } - private void reprintLine(String line) { console.reprintLine(line); diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java b/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java index 3c1bf78c27cda..e9f4275cfd64c 100644 --- a/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java +++ b/presto-cli/src/test/java/com/facebook/presto/cli/TestTableNameCompleter.java @@ -27,7 +27,12 @@ public class TestTableNameCompleter public void testAutoCompleteWithoutSchema() { ClientSession session = new ClientOptions().toClientSession(); - QueryRunner runner = QueryRunner.create(session, + QueryRunner runner = new QueryRunner( + session, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), @@ -38,7 +43,7 @@ public void testAutoCompleteWithoutSchema() Optional.empty(), Optional.empty(), false, - null); + false); TableNameCompleter completer = new TableNameCompleter(runner); assertEquals(completer.complete("SELECT is_infi", 14, ImmutableList.of()), 7); } diff --git a/presto-client/pom.xml b/presto-client/pom.xml index 0dae37aa67d4a..f0e1b7c6cdcfc 100644 --- a/presto-client/pom.xml +++ b/presto-client/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-client @@ -46,11 +46,6 @@ jackson-databind - - io.airlift - http-client - - io.airlift json @@ -66,6 +61,11 @@ guava + + com.squareup.okhttp3 + okhttp + + org.testng diff --git a/presto-client/src/main/java/com/facebook/presto/client/ClientException.java b/presto-client/src/main/java/com/facebook/presto/client/ClientException.java new file mode 100644 index 0000000000000..e9cc75ce70f48 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/ClientException.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +public class ClientException + extends RuntimeException +{ + public ClientException(String message) + { + super(message); + } + + public ClientException(String message, Throwable cause) + { + super(message, cause); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java b/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java new file mode 100644 index 0000000000000..12609e66296a3 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import io.airlift.json.JsonCodec; +import okhttp3.Headers; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.net.HttpHeaders.LOCATION; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public final class JsonResponse +{ + private final int statusCode; + private final String statusMessage; + private final Headers headers; + private final String responseBody; + private final boolean hasValue; + private final T value; + private final IllegalArgumentException exception; + + private JsonResponse(int statusCode, String statusMessage, Headers headers, String responseBody) + { + this.statusCode = statusCode; + this.statusMessage = statusMessage; + this.headers = requireNonNull(headers, "headers is null"); + this.responseBody = requireNonNull(responseBody, "responseBody is null"); + + this.hasValue = false; + this.value = null; + this.exception = null; + } + + private JsonResponse(int statusCode, String statusMessage, Headers headers, String responseBody, JsonCodec jsonCodec) + { + this.statusCode = statusCode; + this.statusMessage = statusMessage; + this.headers = requireNonNull(headers, "headers is null"); + this.responseBody = requireNonNull(responseBody, "responseBody is null"); + + T value = null; + IllegalArgumentException exception = null; + try { + value = jsonCodec.fromJson(responseBody); + } + catch (IllegalArgumentException e) { + exception = new IllegalArgumentException(format("Unable to create %s from JSON response:\n[%s]", jsonCodec.getType(), responseBody), e); + } + this.hasValue = (exception == null); + this.value = value; + this.exception = exception; + } + + public int getStatusCode() + { + return statusCode; + } + + public String getStatusMessage() + { + return statusMessage; + } + + public Headers getHeaders() + { + return headers; + } + + public boolean hasValue() + { + return hasValue; + } + + public T getValue() + { + if (!hasValue) { + throw new IllegalStateException("Response does not contain a JSON value", exception); + } + return value; + } + + public String getResponseBody() + { + return responseBody; + } + + @Nullable + public IllegalArgumentException getException() + { + return exception; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("statusCode", statusCode) + .add("statusMessage", statusMessage) + .add("headers", headers.toMultimap()) + .add("hasValue", hasValue) + .add("value", value) + .omitNullValues() + .toString(); + } + + public static JsonResponse execute(JsonCodec codec, OkHttpClient client, Request request) + { + try (Response response = client.newCall(request).execute()) { + // TODO: fix in OkHttp: https://github.com/square/okhttp/issues/3111 + if ((response.code() == 307) || (response.code() == 308)) { + String location = response.header(LOCATION); + if (location != null) { + request = request.newBuilder().url(location).build(); + return execute(codec, client, request); + } + } + + ResponseBody responseBody = requireNonNull(response.body()); + String body = responseBody.string(); + if (isJson(responseBody.contentType())) { + return new JsonResponse<>(response.code(), response.message(), response.headers(), body, codec); + } + return new JsonResponse<>(response.code(), response.message(), response.headers(), body); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static boolean isJson(MediaType type) + { + return (type != null) && "application".equals(type.type()) && "json".equals(type.subtype()); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/KerberosUtil.java b/presto-client/src/main/java/com/facebook/presto/client/KerberosUtil.java new file mode 100644 index 0000000000000..914def58b2024 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/KerberosUtil.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import java.util.Optional; + +import static com.google.common.base.Strings.emptyToNull; +import static com.google.common.base.Strings.nullToEmpty; + +public final class KerberosUtil +{ + private static final String FILE_PREFIX = "FILE:"; + + private KerberosUtil() {} + + public static Optional defaultCredentialCachePath() + { + String value = nullToEmpty(System.getenv("KRB5CCNAME")); + if (value.startsWith(FILE_PREFIX)) { + value = value.substring(FILE_PREFIX.length()); + } + return Optional.ofNullable(emptyToNull(value)); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java b/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java new file mode 100644 index 0000000000000..4b788a56535ca --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import com.google.common.net.HostAndPort; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.Credentials; +import okhttp3.Interceptor; +import okhttp3.OkHttpClient; +import okhttp3.Response; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.util.Arrays; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static com.google.common.net.HttpHeaders.USER_AGENT; +import static java.net.Proxy.Type.HTTP; +import static java.net.Proxy.Type.SOCKS; +import static java.util.Objects.requireNonNull; + +public final class OkHttpUtil +{ + private OkHttpUtil() {} + + public static class NullCallback + implements Callback + { + @Override + public void onFailure(Call call, IOException e) {} + + @Override + public void onResponse(Call call, Response response) {} + } + + public static Interceptor userAgent(String userAgent) + { + return chain -> chain.proceed(chain.request().newBuilder() + .header(USER_AGENT, userAgent) + .build()); + } + + public static Interceptor basicAuth(String user, String password) + { + requireNonNull(user, "user is null"); + requireNonNull(password, "password is null"); + if (user.contains(":")) { + throw new ClientException("Illegal character ':' found in username"); + } + + String credential = Credentials.basic(user, password); + return chain -> chain.proceed(chain.request().newBuilder() + .header(AUTHORIZATION, credential) + .build()); + } + + public static void setupTimeouts(OkHttpClient.Builder clientBuilder, int timeout, TimeUnit unit) + { + clientBuilder + .connectTimeout(timeout, unit) + .readTimeout(timeout, unit) + .writeTimeout(timeout, unit); + } + + public static void setupSocksProxy(OkHttpClient.Builder clientBuilder, Optional socksProxy) + { + setupProxy(clientBuilder, socksProxy, SOCKS); + } + + public static void setupHttpProxy(OkHttpClient.Builder clientBuilder, Optional httpProxy) + { + setupProxy(clientBuilder, httpProxy, HTTP); + } + + public static void setupProxy(OkHttpClient.Builder clientBuilder, Optional proxy, Proxy.Type type) + { + proxy.map(OkHttpUtil::toUnresolvedAddress) + .map(address -> new Proxy(type, address)) + .ifPresent(clientBuilder::proxy); + } + + private static InetSocketAddress toUnresolvedAddress(HostAndPort address) + { + return InetSocketAddress.createUnresolved(address.getHost(), address.getPort()); + } + + public static void setupSsl( + OkHttpClient.Builder clientBuilder, + Optional keyStorePath, + Optional keyStorePassword, + Optional trustStorePath, + Optional trustStorePassword) + { + if (!keyStorePath.isPresent() && !trustStorePath.isPresent()) { + return; + } + + try { + // load KeyStore if configured and get KeyManagers + KeyStore keyStore = null; + KeyManager[] keyManagers = null; + if (keyStorePath.isPresent()) { + char[] keyPassword = keyStorePassword.map(String::toCharArray).orElse(null); + + keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + try (InputStream in = new FileInputStream(keyStorePath.get())) { + keyStore.load(in, keyPassword); + } + + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, keyPassword); + keyManagers = keyManagerFactory.getKeyManagers(); + } + + // load TrustStore if configured, otherwise use KeyStore + KeyStore trustStore = keyStore; + if (trustStorePath.isPresent()) { + trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + try (InputStream in = new FileInputStream(trustStorePath.get())) { + trustStore.load(in, trustStorePassword.map(String::toCharArray).orElse(null)); + } + } + + // create TrustManagerFactory + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(trustStore); + + // get X509TrustManager + TrustManager[] trustManagers = trustManagerFactory.getTrustManagers(); + if ((trustManagers.length != 1) || !(trustManagers[0] instanceof X509TrustManager)) { + throw new RuntimeException("Unexpected default trust managers:" + Arrays.toString(trustManagers)); + } + X509TrustManager trustManager = (X509TrustManager) trustManagers[0]; + + // create SSLContext + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagers, new TrustManager[] {trustManager}, null); + + clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager); + } + catch (GeneralSecurityException | IOException e) { + throw new ClientException("Error setting up SSL: " + e.getMessage(), e); + } + } + + public static void setupKerberos( + OkHttpClient.Builder clientBuilder, + String remoteServiceName, + boolean useCanonicalHostname, + Optional principal, + Optional kerberosConfig, + Optional keytab, + Optional credentialCache) + { + SpnegoHandler handler = new SpnegoHandler( + remoteServiceName, + useCanonicalHostname, + principal, + kerberosConfig, + keytab, + credentialCache); + clientBuilder.addInterceptor(handler); + clientBuilder.authenticator(handler); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java b/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java index 01207a920709d..30b039215e3fd 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java @@ -15,10 +15,12 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import io.airlift.units.Duration; import javax.annotation.concurrent.Immutable; import java.util.Objects; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; @@ -30,15 +32,20 @@ public class ServerInfo private final String environment; private final boolean coordinator; + // optional to maintain compatibility with older servers + private final Optional uptime; + @JsonCreator public ServerInfo( @JsonProperty("nodeVersion") NodeVersion nodeVersion, @JsonProperty("environment") String environment, - @JsonProperty("coordinator") boolean coordinator) + @JsonProperty("coordinator") boolean coordinator, + @JsonProperty("uptime") Optional uptime) { this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); this.environment = requireNonNull(environment, "environment is null"); this.coordinator = requireNonNull(coordinator, "coordinator is null"); + this.uptime = requireNonNull(uptime, "uptime is null"); } @JsonProperty @@ -59,6 +66,12 @@ public boolean isCoordinator() return coordinator; } + @JsonProperty + public Optional getUptime() + { + return uptime; + } + @Override public boolean equals(Object o) { @@ -87,6 +100,8 @@ public String toString() .add("nodeVersion", nodeVersion) .add("environment", environment) .add("coordinator", coordinator) + .add("uptime", uptime.orElse(null)) + .omitNullValues() .toString(); } } diff --git a/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java b/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java new file mode 100644 index 0000000000000..cb08cb598392c --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java @@ -0,0 +1,332 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableMap; +import com.sun.security.auth.module.Krb5LoginModule; +import io.airlift.units.Duration; +import okhttp3.Authenticator; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.Route; +import org.ietf.jgss.GSSContext; +import org.ietf.jgss.GSSCredential; +import org.ietf.jgss.GSSException; +import org.ietf.jgss.GSSManager; +import org.ietf.jgss.Oid; + +import javax.annotation.concurrent.GuardedBy; +import javax.security.auth.Subject; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.security.Principal; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.Base64; +import java.util.Locale; +import java.util.Optional; + +import static com.google.common.base.CharMatcher.whitespace; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; +import static java.lang.Boolean.getBoolean; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; +import static javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED; +import static org.ietf.jgss.GSSContext.INDEFINITE_LIFETIME; +import static org.ietf.jgss.GSSCredential.DEFAULT_LIFETIME; +import static org.ietf.jgss.GSSCredential.INITIATE_ONLY; +import static org.ietf.jgss.GSSName.NT_HOSTBASED_SERVICE; +import static org.ietf.jgss.GSSName.NT_USER_NAME; + +// TODO: This class is similar to SpnegoAuthentication in Airlift. Consider extracting a library. +public class SpnegoHandler + implements Interceptor, Authenticator +{ + private static final String NEGOTIATE = "Negotiate"; + private static final Duration MIN_CREDENTIAL_LIFETIME = new Duration(60, SECONDS); + + private static final GSSManager GSS_MANAGER = GSSManager.getInstance(); + + private static final Oid SPNEGO_OID = createOid("1.3.6.1.5.5.2"); + private static final Oid KERBEROS_OID = createOid("1.2.840.113554.1.2.2"); + + private final String remoteServiceName; + private final boolean useCanonicalHostname; + private final Optional principal; + private final Optional keytab; + private final Optional credentialCache; + + @GuardedBy("this") + private Session clientSession; + + public SpnegoHandler( + String remoteServiceName, + boolean useCanonicalHostname, + Optional principal, + Optional kerberosConfig, + Optional keytab, + Optional credentialCache) + { + this.remoteServiceName = requireNonNull(remoteServiceName, "remoteServiceName is null"); + this.useCanonicalHostname = useCanonicalHostname; + this.principal = requireNonNull(principal, "principal is null"); + this.keytab = requireNonNull(keytab, "keytab is null"); + this.credentialCache = requireNonNull(credentialCache, "credentialCache is null"); + + kerberosConfig.ifPresent(file -> System.setProperty("java.security.krb5.conf", file.getAbsolutePath())); + } + + @Override + public Response intercept(Chain chain) + throws IOException + { + // eagerly send authentication if possible + try { + return chain.proceed(authenticate(chain.request())); + } + catch (ClientException ignored) { + return chain.proceed(chain.request()); + } + } + + @Override + public Request authenticate(Route route, Response response) + throws IOException + { + // skip if we already tried or were not asked for Kerberos + if (response.request().headers(AUTHORIZATION).stream().anyMatch(SpnegoHandler::isNegotiate) || + response.headers(WWW_AUTHENTICATE).stream().noneMatch(SpnegoHandler::isNegotiate)) { + return null; + } + + return authenticate(response.request()); + } + + private static boolean isNegotiate(String value) + { + return Splitter.on(whitespace()).split(value).iterator().next().equalsIgnoreCase(NEGOTIATE); + } + + private Request authenticate(Request request) + { + String hostName = request.url().host(); + String principal = makeServicePrincipal(remoteServiceName, hostName, useCanonicalHostname); + byte[] token = generateToken(principal); + + String credential = format("%s %s", NEGOTIATE, Base64.getEncoder().encodeToString(token)); + return request.newBuilder() + .header(AUTHORIZATION, credential) + .build(); + } + + private byte[] generateToken(String servicePrincipal) + { + GSSContext context = null; + try { + Session session = getSession(); + context = doAs(session.getLoginContext().getSubject(), () -> { + GSSContext result = GSS_MANAGER.createContext( + GSS_MANAGER.createName(servicePrincipal, NT_HOSTBASED_SERVICE), + SPNEGO_OID, + session.getClientCredential(), + INDEFINITE_LIFETIME); + + result.requestMutualAuth(true); + result.requestConf(true); + result.requestInteg(true); + result.requestCredDeleg(false); + return result; + }); + + byte[] token = context.initSecContext(new byte[0], 0, 0); + if (token == null) { + throw new LoginException("No token generated from GSS context"); + } + return token; + } + catch (GSSException | LoginException e) { + throw new ClientException(format("Kerberos error for [%s]: %s", servicePrincipal, e.getMessage()), e); + } + finally { + try { + if (context != null) { + context.dispose(); + } + } + catch (GSSException ignored) { + } + } + } + + private synchronized Session getSession() + throws LoginException, GSSException + { + if ((clientSession == null) || clientSession.needsRefresh()) { + clientSession = createSession(); + } + return clientSession; + } + + private Session createSession() + throws LoginException, GSSException + { + // TODO: do we need to call logout() on the LoginContext? + + LoginContext loginContext = new LoginContext("", null, null, new Configuration() + { + @Override + public AppConfigurationEntry[] getAppConfigurationEntry(String name) + { + ImmutableMap.Builder options = ImmutableMap.builder(); + options.put("refreshKrb5Config", "true"); + options.put("doNotPrompt", "true"); + options.put("useKeyTab", "true"); + + if (getBoolean("presto.client.debugKerberos")) { + options.put("debug", "true"); + } + + keytab.ifPresent(file -> options.put("keyTab", file.getAbsolutePath())); + + credentialCache.ifPresent(file -> { + options.put("ticketCache", file.getAbsolutePath()); + options.put("useTicketCache", "true"); + options.put("renewTGT", "true"); + }); + + principal.ifPresent(value -> options.put("principal", value)); + + return new AppConfigurationEntry[] { + new AppConfigurationEntry(Krb5LoginModule.class.getName(), REQUIRED, options.build()) + }; + } + }); + + loginContext.login(); + Subject subject = loginContext.getSubject(); + Principal clientPrincipal = subject.getPrincipals().iterator().next(); + GSSCredential clientCredential = doAs(subject, () -> GSS_MANAGER.createCredential( + GSS_MANAGER.createName(clientPrincipal.getName(), NT_USER_NAME), + DEFAULT_LIFETIME, + KERBEROS_OID, + INITIATE_ONLY)); + + return new Session(loginContext, clientCredential); + } + + private static String makeServicePrincipal(String serviceName, String hostName, boolean useCanonicalHostname) + { + String serviceHostName = hostName; + if (useCanonicalHostname) { + serviceHostName = canonicalizeServiceHostName(hostName); + } + return format("%s@%s", serviceName, serviceHostName.toLowerCase(Locale.US)); + } + + private static String canonicalizeServiceHostName(String hostName) + { + try { + InetAddress address = InetAddress.getByName(hostName); + String fullHostName; + if ("localhost".equalsIgnoreCase(address.getHostName())) { + fullHostName = InetAddress.getLocalHost().getCanonicalHostName(); + } + else { + fullHostName = address.getCanonicalHostName(); + } + if (fullHostName.equalsIgnoreCase("localhost")) { + throw new ClientException("Fully qualified name of localhost should not resolve to 'localhost'. System configuration error?"); + } + return fullHostName; + } + catch (UnknownHostException e) { + throw new ClientException("Failed to resolve host: " + hostName, e); + } + } + + private interface GssSupplier + { + T get() + throws GSSException; + } + + private static T doAs(Subject subject, GssSupplier action) + throws GSSException + { + try { + return Subject.doAs(subject, (PrivilegedExceptionAction) action::get); + } + catch (PrivilegedActionException e) { + Throwable t = e.getCause(); + throwIfInstanceOf(t, GSSException.class); + throwIfUnchecked(t); + throw new RuntimeException(t); + } + } + + private static Oid createOid(String value) + { + try { + return new Oid(value); + } + catch (GSSException e) { + throw new AssertionError(e); + } + } + + private static class Session + { + private final LoginContext loginContext; + private final GSSCredential clientCredential; + + public Session(LoginContext loginContext, GSSCredential clientCredential) + throws LoginException + { + requireNonNull(loginContext, "loginContext is null"); + requireNonNull(clientCredential, "gssCredential is null"); + + this.loginContext = loginContext; + this.clientCredential = clientCredential; + } + + public LoginContext getLoginContext() + { + return loginContext; + } + + public GSSCredential getClientCredential() + { + return clientCredential; + } + + public boolean needsRefresh() + throws GSSException + { + return clientCredential.getRemainingLifetime() < MIN_CREDENTIAL_LIFETIME.getValue(SECONDS); + } + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java b/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java index 7373f17338af9..785a73b9ae38c 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java +++ b/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java @@ -13,19 +13,19 @@ */ package com.facebook.presto.client; +import com.facebook.presto.client.OkHttpUtil.NullCallback; import com.facebook.presto.spi.type.TimeZoneKey; import com.google.common.base.Splitter; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.airlift.http.client.FullJsonResponseHandler; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClient.HttpResponseFuture; -import io.airlift.http.client.HttpStatus; -import io.airlift.http.client.Request; import io.airlift.json.JsonCodec; -import io.airlift.units.Duration; +import okhttp3.Headers; +import okhttp3.HttpUrl; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; import javax.annotation.concurrent.ThreadSafe; @@ -37,35 +37,36 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import static com.facebook.presto.client.PrestoHeaders.PRESTO_ADDED_PREPARE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLEAR_SESSION; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLEAR_TRANSACTION_ID; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; import static com.facebook.presto.client.PrestoHeaders.PRESTO_DEALLOCATED_PREPARE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_LANGUAGE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_PREPARED_STATEMENT; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_SCHEMA; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_SESSION; import static com.facebook.presto.client.PrestoHeaders.PRESTO_SET_SESSION; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_SOURCE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_STARTED_TRANSACTION_ID; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_TIME_ZONE; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRANSACTION_ID; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.net.HttpHeaders.USER_AGENT; -import static io.airlift.http.client.FullJsonResponseHandler.JsonResponse; -import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; -import static io.airlift.http.client.HttpStatus.Family; -import static io.airlift.http.client.HttpStatus.familyForStatusCode; -import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static io.airlift.http.client.Request.Builder.prepareDelete; -import static io.airlift.http.client.Request.Builder.prepareGet; -import static io.airlift.http.client.Request.Builder.preparePost; -import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; -import static io.airlift.http.client.StatusResponseHandler.StatusResponse; -import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; +import static io.airlift.json.JsonCodec.jsonCodec; import static java.lang.String.format; -import static java.nio.charset.StandardCharsets.UTF_8; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; +import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -74,13 +75,15 @@ public class StatementClient implements Closeable { + private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("application/json; charset=utf-8"); + private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); + private static final Splitter SESSION_HEADER_SPLITTER = Splitter.on('=').limit(2).trimResults(); private static final String USER_AGENT_VALUE = StatementClient.class.getSimpleName() + "/" + firstNonNull(StatementClient.class.getPackage().getImplementationVersion(), "unknown"); - private final HttpClient httpClient; - private final FullJsonResponseHandler responseHandler; + private final OkHttpClient httpClient; private final boolean debug; private final String query; private final AtomicReference currentResults = new AtomicReference<>(); @@ -97,15 +100,13 @@ public class StatementClient private final long requestTimeoutNanos; private final String user; - public StatementClient(HttpClient httpClient, JsonCodec queryResultsCodec, ClientSession session, String query) + public StatementClient(OkHttpClient httpClient, ClientSession session, String query) { requireNonNull(httpClient, "httpClient is null"); - requireNonNull(queryResultsCodec, "queryResultsCodec is null"); requireNonNull(session, "session is null"); requireNonNull(query, "query is null"); this.httpClient = httpClient; - this.responseHandler = createFullJsonResponseHandler(queryResultsCodec); this.debug = session.isDebug(); this.timeZone = session.getTimeZone(); this.query = query; @@ -113,48 +114,54 @@ public StatementClient(HttpClient httpClient, JsonCodec queryResul this.user = session.getUser(); Request request = buildQueryRequest(session, query); - JsonResponse response = httpClient.execute(request, responseHandler); - if (response.getStatusCode() != HttpStatus.OK.code() || !response.hasValue()) { + JsonResponse response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpClient, request); + if ((response.getStatusCode() != HTTP_OK) || !response.hasValue()) { throw requestFailedException("starting query", request, response); } - processResponse(response); + processResponse(response.getHeaders(), response.getValue()); } private Request buildQueryRequest(ClientSession session, String query) { - Request.Builder builder = prepareRequest(preparePost(), uriBuilderFrom(session.getServer()).replacePath("/v1/statement").build()) - .setBodyGenerator(createStaticBodyGenerator(query, UTF_8)); + HttpUrl url = HttpUrl.get(session.getServer()); + if (url == null) { + throw new ClientException("Invalid server URL: " + session.getServer()); + } + url = url.newBuilder().encodedPath("/v1/statement").build(); + + Request.Builder builder = prepareRequest(url) + .post(RequestBody.create(MEDIA_TYPE_JSON, query)); if (session.getSource() != null) { - builder.setHeader(PrestoHeaders.PRESTO_SOURCE, session.getSource()); + builder.addHeader(PRESTO_SOURCE, session.getSource()); } if (session.getClientInfo() != null) { - builder.setHeader(PrestoHeaders.PRESTO_CLIENT_INFO, session.getClientInfo()); + builder.addHeader(PRESTO_CLIENT_INFO, session.getClientInfo()); } if (session.getCatalog() != null) { - builder.setHeader(PrestoHeaders.PRESTO_CATALOG, session.getCatalog()); + builder.addHeader(PRESTO_CATALOG, session.getCatalog()); } if (session.getSchema() != null) { - builder.setHeader(PrestoHeaders.PRESTO_SCHEMA, session.getSchema()); + builder.addHeader(PRESTO_SCHEMA, session.getSchema()); } - builder.setHeader(PrestoHeaders.PRESTO_TIME_ZONE, session.getTimeZone().getId()); + builder.addHeader(PRESTO_TIME_ZONE, session.getTimeZone().getId()); if (session.getLocale() != null) { - builder.setHeader(PrestoHeaders.PRESTO_LANGUAGE, session.getLocale().toLanguageTag()); + builder.addHeader(PRESTO_LANGUAGE, session.getLocale().toLanguageTag()); } Map property = session.getProperties(); for (Entry entry : property.entrySet()) { - builder.addHeader(PrestoHeaders.PRESTO_SESSION, entry.getKey() + "=" + entry.getValue()); + builder.addHeader(PRESTO_SESSION, entry.getKey() + "=" + entry.getValue()); } Map statements = session.getPreparedStatements(); for (Entry entry : statements.entrySet()) { - builder.addHeader(PrestoHeaders.PRESTO_PREPARED_STATEMENT, urlEncode(entry.getKey()) + "=" + urlEncode(entry.getValue())); + builder.addHeader(PRESTO_PREPARED_STATEMENT, urlEncode(entry.getKey()) + "=" + urlEncode(entry.getValue())); } - builder.setHeader(PrestoHeaders.PRESTO_TRANSACTION_ID, session.getTransactionId() == null ? "NONE" : session.getTransactionId()); + builder.addHeader(PRESTO_TRANSACTION_ID, session.getTransactionId() == null ? "NONE" : session.getTransactionId()); return builder.build(); } @@ -241,13 +248,12 @@ public boolean isValid() return valid.get() && (!isGone()) && (!isClosed()); } - private Request.Builder prepareRequest(Request.Builder builder, URI nextUri) + private Request.Builder prepareRequest(HttpUrl url) { - builder.setHeader(PrestoHeaders.PRESTO_USER, user); - builder.setHeader(USER_AGENT, USER_AGENT_VALUE) - .setUri(nextUri); - - return builder; + return new Request.Builder() + .addHeader(PRESTO_USER, user) + .addHeader(USER_AGENT, USER_AGENT_VALUE) + .url(url); } public boolean advance() @@ -258,7 +264,7 @@ public boolean advance() return false; } - Request request = prepareRequest(prepareGet(), nextUri).build(); + Request request = prepareRequest(HttpUrl.get(nextUri)).build(); Exception cause = null; long start = System.nanoTime(); @@ -284,19 +290,19 @@ public boolean advance() JsonResponse response; try { - response = httpClient.execute(request, responseHandler); + response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpClient, request); } catch (RuntimeException e) { cause = e; continue; } - if (response.getStatusCode() == HttpStatus.OK.code() && response.hasValue()) { - processResponse(response); + if ((response.getStatusCode() == HTTP_OK) && response.hasValue()) { + processResponse(response.getHeaders(), response.getValue()); return true; } - if (response.getStatusCode() != HttpStatus.SERVICE_UNAVAILABLE.code()) { + if (response.getStatusCode() != HTTP_UNAVAILABLE) { throw requestFailedException("fetching next", request, response); } } @@ -306,77 +312,65 @@ public boolean advance() throw new RuntimeException("Error fetching next", cause); } - private void processResponse(JsonResponse response) + private void processResponse(Headers headers, QueryResults results) { - for (String setSession : response.getHeaders(PRESTO_SET_SESSION)) { + for (String setSession : headers.values(PRESTO_SET_SESSION)) { List keyValue = SESSION_HEADER_SPLITTER.splitToList(setSession); if (keyValue.size() != 2) { continue; } setSessionProperties.put(keyValue.get(0), keyValue.size() > 1 ? keyValue.get(1) : ""); } - for (String clearSession : response.getHeaders(PRESTO_CLEAR_SESSION)) { + for (String clearSession : headers.values(PRESTO_CLEAR_SESSION)) { resetSessionProperties.add(clearSession); } - for (String entry : response.getHeaders(PRESTO_ADDED_PREPARE)) { + for (String entry : headers.values(PRESTO_ADDED_PREPARE)) { List keyValue = SESSION_HEADER_SPLITTER.splitToList(entry); if (keyValue.size() != 2) { continue; } addedPreparedStatements.put(urlDecode(keyValue.get(0)), urlDecode(keyValue.get(1))); } - for (String entry : response.getHeaders(PRESTO_DEALLOCATED_PREPARE)) { + for (String entry : headers.values(PRESTO_DEALLOCATED_PREPARE)) { deallocatedPreparedStatements.add(urlDecode(entry)); } - String startedTransactionId = response.getHeader(PRESTO_STARTED_TRANSACTION_ID); + String startedTransactionId = headers.get(PRESTO_STARTED_TRANSACTION_ID); if (startedTransactionId != null) { this.startedtransactionId.set(startedTransactionId); } - if (response.getHeader(PRESTO_CLEAR_TRANSACTION_ID) != null) { + if (headers.values(PRESTO_CLEAR_TRANSACTION_ID) != null) { clearTransactionId.set(true); } - currentResults.set(response.getValue()); + currentResults.set(results); } private RuntimeException requestFailedException(String task, Request request, JsonResponse response) { gone.set(true); if (!response.hasValue()) { + if (response.getStatusCode() == HTTP_UNAUTHORIZED) { + return new ClientException("Authentication failed" + + Optional.ofNullable(response.getStatusMessage()) + .map(message -> ": " + message) + .orElse("")); + } return new RuntimeException( - format("Error %s at %s returned an invalid response: %s [Error: %s]", task, request.getUri(), response, response.getResponseBody()), + format("Error %s at %s returned an invalid response: %s [Error: %s]", task, request.url(), response, response.getResponseBody()), response.getException()); } - return new RuntimeException(format("Error %s at %s returned %s: %s", task, request.getUri(), response.getStatusCode(), response.getStatusMessage())); + return new RuntimeException(format("Error %s at %s returned HTTP %s", task, request.url(), response.getStatusCode())); } - public boolean cancelLeafStage(Duration timeout) + public void cancelLeafStage() { checkState(!isClosed(), "client is closed"); URI uri = current().getPartialCancelUri(); - if (uri == null) { - return false; - } - - Request request = prepareRequest(prepareDelete(), uri).build(); - - HttpResponseFuture response = httpClient.executeAsync(request, createStatusResponseHandler()); - try { - StatusResponse status = response.get(timeout.toMillis(), MILLISECONDS); - return familyForStatusCode(status.getStatusCode()) == Family.SUCCESSFUL; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw Throwables.propagate(e); - } - catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); - } - catch (TimeoutException e) { - return false; + if (uri != null) { + httpDelete(uri); } } @@ -386,12 +380,19 @@ public void close() if (!closed.getAndSet(true)) { URI uri = currentResults.get().getNextUri(); if (uri != null) { - Request request = prepareRequest(prepareDelete(), uri).build(); - httpClient.executeAsync(request, createStatusResponseHandler()); + httpDelete(uri); } } } + private void httpDelete(URI uri) + { + Request request = prepareRequest(HttpUrl.get(uri)) + .delete() + .build(); + httpClient.newCall(request).enqueue(new NullCallback()); + } + private static String urlEncode(String value) { try { diff --git a/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java b/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java new file mode 100644 index 0000000000000..036851d4252a8 --- /dev/null +++ b/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client; + +import io.airlift.json.JsonCodec; +import io.airlift.units.Duration; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.client.NodeVersion.UNKNOWN; +import static io.airlift.json.JsonCodec.jsonCodec; +import static org.testng.Assert.assertEquals; + +public class TestServerInfo +{ + private static final JsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); + + @Test + public void testJsonRoundTrip() + { + assertJsonRoundTrip(new ServerInfo(UNKNOWN, "test", true, Optional.of(Duration.valueOf("2m")))); + assertJsonRoundTrip(new ServerInfo(UNKNOWN, "test", true, Optional.empty())); + } + + @Test + public void testBackwardsCompatible() + { + ServerInfo newServerInfo = new ServerInfo(UNKNOWN, "test", true, Optional.empty()); + ServerInfo legacyServerInfo = SERVER_INFO_CODEC.fromJson("{\"nodeVersion\":{\"version\":\"\"},\"environment\":\"test\",\"coordinator\":true}"); + assertEquals(newServerInfo, legacyServerInfo); + } + + private static void assertJsonRoundTrip(ServerInfo serverInfo) + { + String json = SERVER_INFO_CODEC.toJson(serverInfo); + ServerInfo copy = SERVER_INFO_CODEC.fromJson(json); + assertEquals(copy, serverInfo); + } +} diff --git a/presto-docs/pom.xml b/presto-docs/pom.xml index 36d725071547c..ae09ccdb7e089 100644 --- a/presto-docs/pom.xml +++ b/presto-docs/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-docs @@ -44,6 +44,38 @@ + + org.codehaus.mojo + exec-maven-plugin + + + validate-reserved + validate + + java + + + com.facebook.presto.sql.ReservedIdentifiers + + validateDocs + ${project.basedir}/src/main/sphinx/language/reserved.rst + + + + + + false + true + + + + com.facebook.presto + presto-parser + ${project.version} + + + + io.airlift.maven.plugins sphinx-maven-plugin diff --git a/presto-docs/src/main/sphinx/admin.rst b/presto-docs/src/main/sphinx/admin.rst index eab299a0a3925..0af2c6e6a3c0b 100644 --- a/presto-docs/src/main/sphinx/admin.rst +++ b/presto-docs/src/main/sphinx/admin.rst @@ -7,5 +7,6 @@ Administration admin/web-interface admin/tuning + admin/properties admin/queue admin/resource-groups diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst new file mode 100644 index 0000000000000..6284c52f07bf1 --- /dev/null +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -0,0 +1,416 @@ +==================== +Properties Reference +==================== + +This section describes the most important config properties that +may be used to tune Presto or alter its behavior when required. + +.. contents:: + :local: + :backlinks: none + :depth: 1 + +General Properties +------------------ + +``distributed-joins-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + + Use hash distributed joins instead of broadcast joins. Distributed joins + require redistributing both tables using a hash of the join key. This can + be slower (sometimes substantially) than broadcast joins, but allows much + larger joins. Broadcast joins require that the tables on the right side of + the join after filtering fit in memory on each node, whereas distributed joins + only need to fit in distributed memory across all nodes. This can also be + specified on a per-query basis using the ``distributed_join`` session property. + +``redistribute-writes`` +^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + + This property enables redistribution of data before writing. This can + eliminate the performance impact of data skew when writing by hashing it + across nodes in the cluster. It can be disabled when it is known that the + output data set is not skewed in order to avoid the overhead of hashing and + redistributing all the data across the network. This can also be specified + on a per-query basis using the ``redistribute_writes`` session property. + +``resources.reserved-system-memory`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``data size`` + * **Default value:** ``JVM max memory * 0.4`` + + The amount of JVM memory reserved, for accounting purposes, for things + that are not directly attributable to or controllable by a user query. + For example, output buffers, code caches, etc. This also accounts for + memory that is not tracked by the memory tracking system. + + The purpose of this property is to prevent the JVM from running out of + memory (OOM). The default value is suitable for smaller JVM heap sizes or + clusters with many concurrent queries. If running fewer queries with a + large heap, a smaller value may work. Basically, set this value large + enough that the JVM does not fail with ``OutOfMemoryError``. + + +Exchange Properties +------------------- + +Exchanges transfer data between Presto nodes for different stages of +a query. Adjusting these properties may help to resolve inter-node +communication issues or improve network utilization. + +``exchange.client-threads`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``1`` + * **Default value:** ``25`` + + Number of threads used by exchange clients to fetch data from other Presto + nodes. A higher value can improve performance for large clusters or clusters + with very high concurrency, but excessively high values may cause a drop + in performance due to context switches and additional memory usage. + +``exchange.concurrent-request-multiplier`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``1`` + * **Default value:** ``3`` + + Multiplier determining the number of concurrent requests relative to + available buffer memory. The maximum number of requests is determined + using a heuristic of the number of clients that can fit into available + buffer space based on average buffer usage per request times this + multiplier. For example, with an ``exchange.max-buffer-size`` of ``32 MB`` + and ``20 MB`` already used and average size per request being ``2MB``, + the maximum number of clients is + ``multiplier * ((32MB - 20MB) / 2MB) = multiplier * 6``. Tuning this + value adjusts the heuristic, which may increase concurrency and improve + network utilization. + +``exchange.max-buffer-size`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``data size`` + * **Default value:** ``32MB`` + + Size of buffer in the exchange client that holds data fetched from other + nodes before it is processed. A larger buffer can increase network + throughput for larger clusters and thus decrease query processing time, + but will reduce the amount of memory available for other usages. + +``exchange.max-response-size`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``data size`` + * **Minimum value:** ``1MB`` + * **Default value:** ``16MB`` + + Maximum size of a response returned from an exchange request. The response + will be placed in the exchange client buffer which is shared across all + concurrent requests for the exchange. + + Increasing the value may improve network throughput if there is high + latency. Decreasing the value may improve query performance for large + clusters as it reduces skew due to the exchange client buffer holding + responses for more tasks (rather than hold more data from fewer tasks). + +``sink.max-buffer-size`` +^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``data size`` + * **Default value:** ``32MB`` + + Output buffer size for task data that is waiting to be pulled by upstream + tasks. If the task output is hash partitioned, then the buffer will be + shared across all of the partitioned consumers. Increasing this value may + improve network throughput for data transferred between stages if the + network has high latency or if there are many nodes in the cluster. + + +Task Properties +--------------- + +``task.concurrency`` +^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Restrictions:** must be a power of two + * **Default value:** ``16`` + + Default local concurrency for parallel operators such as joins and aggregations. + This value should be adjusted up or down based on the query concurrency and worker + resource utilization. Lower values are better for clusters that run many queries + concurrently because the cluster will already be utilized by all the running + queries, so adding more concurrency will result in slow downs due to context + switching and other overhead. Higher values are better for clusters that only run + one or a few queries at a time. This can also be specified on a per-query basis + using the ``task_concurrency`` session property. + +``task.http-response-threads`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``1`` + * **Default value:** ``100`` + + Maximum number of threads that may be created to handle HTTP responses. Threads are + created on demand and are cleaned up when idle, thus there is no overhead to a large + value if the number of requests to be handled is small. More threads may be helpful + on clusters with a high number of concurrent queries, or on clusters with hundreds + or thousands of workers. + +``task.http-timeout-threads`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``1`` + * **Default value:** ``3`` + + Number of threads used to handle timeouts when generating HTTP responses. This value + should be increased if all the threads are frequently in use. This can be monitored + via the ``com.facebook.presto.server:name=AsyncHttpExecutionMBean:TimeoutExecutor`` + JMX object. If ``ActiveCount`` is always the same as ``PoolSize``, increase the + number of threads. + +``task.info-update-interval`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``duration`` + * **Minimum value:** ``1ms`` + * **Maximum value:** ``10s`` + * **Default value:** ``3s`` + + Controls staleness of task information, which is used in scheduling. Larger values + can reduce coordinator CPU load, but may result in suboptimal split scheduling. + +``task.max-partial-aggregation-memory`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``data size`` + * **Default value:** ``16MB`` + + Maximum size of partial aggregation results for distributed aggregations. Increasing this + value can result in less network transfer and lower CPU utilization by allowing more + groups to be kept locally before being flushed, at the cost of additional memory usage. + +``task.max-worker-threads`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Default value:** ``Node CPUs * 2`` + + Sets the number of threads used by workers to process splits. Increasing this number + can improve throughput if worker CPU utilization is low and all the threads are in use, + but will cause increased heap space usage. Setting the value too high may cause a drop + in performance due to a context switching. The number of active threads is available + via the ``RunningSplits`` property of the + ``com.facebook.presto.execution.executor:name=TaskExecutor.RunningSplits`` JXM object. + +``task.min-drivers`` +^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Default value:** ``task.max-worker-threads * 2`` + + The target number of running leaf splits on a worker. This is a minimum value because + each leaf task is guaranteed at least ``3`` running splits. Non-leaf tasks are also + guaranteed to run in order to prevent deadlocks. A lower value may improve responsiveness + for new tasks, but can result in underutilized resources. A higher value can increase + resource utilization, but uses additional memory. + +``task.writer-count`` +^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Restrictions:** must be a power of two + * **Default value:** ``1`` + + The number of concurrent writer threads per worker per query. Increasing this value may + increase write speed, especially when a query is not I/O bound and can take advantage + of additional CPU for parallel writes (some connectors can be bottlenecked on CPU when + writing due to compression or other factors). Setting this too high may cause the cluster + to become overloaded due to excessive resource utilization. This can also be specified on + a per-query basis using the ``task_writer_count`` session property. + + +Node Scheduler Properties +------------------------- + +``node-scheduler.max-splits-per-node`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Default value:** ``100`` + + The target value for the total number of splits that can be running for + each worker node. + + Using a higher value is recommended if queries are submitted in large batches + (e.g., running a large group of reports periodically) or for connectors that + produce many splits that complete quickly. Increasing this value may improve + query latency by ensuring that the workers have enough splits to keep them + fully utilized. + + Setting this too high will waste memory and may result in lower performance + due to splits not being balanced across workers. Ideally, it should be set + such that there is always at least one split waiting to be processed, but + not higher. + +``node-scheduler.max-pending-splits-per-task`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Default value:** ``10`` + + The number of outstanding splits that can be queued for each worker node + for a single stage of a query, even when the node is already at the limit for + total number of splits. Allowing a minimum number of splits per stage is + required to prevent starvation and deadlocks. + + This value must be smaller than ``node-scheduler.max-splits-per-node``, + will usually be increased for the same reasons, and has similar drawbacks + if set too high. + +``node-scheduler.min-candidates`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``1`` + * **Default value:** ``10`` + + The minimum number of candidate nodes that will be evaluated by the + node scheduler when choosing the target node for a split. Setting + this value too low may prevent splits from being properly balanced + across all worker nodes. Setting it too high may increase query + latency and increase CPU usage on the coordinator. + +``node-scheduler.network-topology`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``string`` + * **Allowed values:** ``legacy``, ``flat`` + * **Default value:** ``legacy`` + + +Optimizer Properties +-------------------- + +``optimizer.dictionary-aggregation`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``false`` + + Enables optimization for aggregations on dictionaries. This can also be specified + on a per-query basis using the ``dictionary_aggregation`` session property. + +``optimizer.optimize-hash-generation`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + + Compute hash codes for distribution, joins, and aggregations early during execution, + allowing result to be shared between operations later in the query. This can reduce + CPU usage by avoiding computing the same hash multiple times, but at the cost of + additional network transfer for the hashes. In most cases it will decrease overall + query processing time. This can also be specified on a per-query basis using the + ``optimize_hash_generation`` session property. + + It is often helpful to disable this property when using :doc:`/sql/explain` in order + to make the query plan easier to read. + +``optimizer.optimize-metadata-queries`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``false`` + + Enable optimization of some aggregations by using values that are stored as metadata. + This allows Presto to execute some simple queries in constant time. Currently, this + optimization applies to ``max``, ``min`` and ``approx_distinct`` of partition + keys and other aggregation insensitive to the cardinality of the input (including + ``DISTINCT`` aggregates). Using this may speed up some queries significantly. + + The main drawback is that it can produce incorrect results if the connector returns + partition keys for partitions that have no rows. In particular, the Hive connector + can return empty partitions if they were created by other systems (Presto cannot + create them). + +``optimizer.optimize-single-distinct`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + + The single distinct optimization will try to replace multiple ``DISTINCT`` clauses + with a single ``GROUP BY`` clause, which can be substantially faster to execute. + +``optimizer.push-table-write-through-union`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``boolean`` + * **Default value:** ``true`` + + Parallelize writes when using ``UNION ALL`` in queries that write data. This improves the + speed of writing output tables in ``UNION ALL`` queries because these writes do not require + additional synchronization when collecting results. Enabling this optimization can improve + ``UNION ALL`` speed when write speed is not yet saturated. However, it may slow down queries + in an already heavily loaded system. This can also be specified on a per-query basis + using the ``push_table_write_through_union`` session property. + + +Regular Expression Function Properties +-------------------------------------- + +The following properties allow tuning the :doc:`/functions/regexp`. + +``regex-library`` +^^^^^^^^^^^^^^^^^ + + * **Type:** ``string`` + * **Allowed values:** ``JONI``, ``RE2J`` + * **Default value:** ``JONI`` + + Which library to use for regular expression functions. + ``JONI`` is generally faster for common usage, but can require exponential + time for certain expression patterns. ``RE2J`` uses a different algorithm + which guarantees linear time, but is often slower. + +``re2j.dfa-states-limit`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``2`` + * **Default value:** ``2147483647`` + + The maximum number of states to use when RE2J builds the fast + but potentially memory intensive deterministic finite automaton (DFA) + for regular expression matching. If the limit is reached, RE2J will fall + back to the algorithm that uses the slower, but less memory intensive + non-deterministic finite automaton (NFA). Decreasing this value decreases the + maximum memory footprint of a regular expression search at the cost of speed. + +``re2j.dfa-retries`` +^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``integer`` + * **Minimum value:** ``0`` + * **Default value:** ``5`` + + The number of times that RE2J will retry the DFA algorithm when + it reaches a states limit before using the slower, but less memory + intensive NFA algorithm for all future inputs for that search. If hitting the + limit for a given input row is likely to be an outlier, you want to be able + to process subsequent rows using the faster DFA algorithm. If you are likely + to hit the limit on matches for subsequent rows as well, you want to use the + correct algorithm from the beginning so as not to waste time and resources. + The more rows you are processing, the larger this value should be. diff --git a/presto-docs/src/main/sphinx/admin/resource-groups.rst b/presto-docs/src/main/sphinx/admin/resource-groups.rst index 2d61dff52d3ad..f17881b387cb0 100644 --- a/presto-docs/src/main/sphinx/admin/resource-groups.rst +++ b/presto-docs/src/main/sphinx/admin/resource-groups.rst @@ -96,7 +96,7 @@ There are three selectors that define which queries run in which resource group: * The last selector is a catch all, which puts all queries into the user's adhoc group. All together these selectors implement the policy that ``bob`` is an admin and -all other users are subject to the follow limits: +all other users are subject to the following limits: * Users are allowed to have up to 2 adhoc queries running. Additionally, they may run one pipeline. @@ -146,7 +146,7 @@ all other users are subject to the follow limits: { "name": "admin", "softMemoryLimit": "100%", - "maxRunning": 100, + "maxRunning": 200, "maxQueued": 100, "schedulingPolicy": "query_priority", "jmxExport": true diff --git a/presto-docs/src/main/sphinx/admin/tuning.rst b/presto-docs/src/main/sphinx/admin/tuning.rst index 8d79b73ad1fdd..dabf0c25a7a5c 100644 --- a/presto-docs/src/main/sphinx/admin/tuning.rst +++ b/presto-docs/src/main/sphinx/admin/tuning.rst @@ -8,27 +8,7 @@ information may help you if your cluster is facing a specific performance proble Config Properties ----------------- -These configuration options may require tuning in specific situations: - -* ``task.max-worker-threads``: - Sets the number of threads used by workers to process splits. Increasing this number - can improve throughput if worker CPU utilization is low and all the threads are in use, - but will cause increased heap space usage. The number of active threads is available via - the ``com.facebook.presto.execution.executor.TaskExecutor.RunningSplits`` JMX stat. - -* ``distributed-joins-enabled``: - Use hash distributed joins instead of broadcast joins. Distributed joins - require redistributing both tables using a hash of the join key. This can - be slower (sometimes substantially) than broadcast joins, but allows much - larger joins. Broadcast joins require that the tables on the right side of - the join fit in memory on each machine, whereas with distributed joins the - tables on the right side have to fit in distributed memory. This can also be - specified on a per-query basis using the ``distributed_join`` session property. - -* ``node-scheduler.network-topology``: - Sets the network topology to use when scheduling splits. "legacy" will ignore - the topology when scheduling splits. "flat" will try to schedule splits on the same - host as the data is located by reserving 50% of the work queue for local splits. +See :doc:`/admin/properties`. JVM Settings ------------ diff --git a/presto-docs/src/main/sphinx/connector/accumulo.rst b/presto-docs/src/main/sphinx/connector/accumulo.rst index 6c77836abf8c8..49af8f9eec17d 100644 --- a/presto-docs/src/main/sphinx/connector/accumulo.rst +++ b/presto-docs/src/main/sphinx/connector/accumulo.rst @@ -480,20 +480,22 @@ Note that session properties are prefixed with the catalog name:: SET SESSION accumulo.column_filter_optimizations_enabled = false; -======================================== ============= ======================================================================================================= -Property Name Default Value Description -======================================== ============= ======================================================================================================= -``optimize_locality_enabled`` ``true`` Set to true to enable data locality for non-indexed scans -``optimize_split_ranges_enabled`` ``true`` Set to true to split non-indexed queries by tablet splits. Should generally be true. -``optimize_index_enabled`` ``true`` Set to true to enable usage of the secondary index on query -``index_rows_per_split`` ``10000`` The number of Accumulo row IDs that are packed into a single Presto split -``index_threshold`` ``0.2`` The ratio between number of rows to be scanned based on the index over the total number of rows. - If the ratio is below this threshold, the index will be used. -``index_lowest_cardinality_threshold`` ``0.01`` The threshold where the column with the lowest cardinality will be used instead of computing an - intersection of ranges in the index. Secondary index must be enabled. -``index_metrics_enabled`` ``true`` Set to true to enable usage of the metrics table to optimize usage of the index -``scan_username`` (config) User to impersonate when scanning the tables. This property trumps the ``scan_auths`` table property. -======================================== ============= ======================================================================================================= +============================================= ============= ======================================================================================================= +Property Name Default Value Description +============================================= ============= ======================================================================================================= +``optimize_locality_enabled`` ``true`` Set to true to enable data locality for non-indexed scans +``optimize_split_ranges_enabled`` ``true`` Set to true to split non-indexed queries by tablet splits. Should generally be true. +``optimize_index_enabled`` ``true`` Set to true to enable usage of the secondary index on query +``index_rows_per_split`` ``10000`` The number of Accumulo row IDs that are packed into a single Presto split +``index_threshold`` ``0.2`` The ratio between number of rows to be scanned based on the index over the total number of rows + If the ratio is below this threshold, the index will be used. +``index_lowest_cardinality_threshold`` ``0.01`` The threshold where the column with the lowest cardinality will be used instead of computing an + intersection of ranges in the index. Secondary index must be enabled +``index_metrics_enabled`` ``true`` Set to true to enable usage of the metrics table to optimize usage of the index +``scan_username`` (config) User to impersonate when scanning the tables. This property trumps the ``scan_auths`` table property +``index_short_circuit_cardinality_fetch`` ``true`` Short circuit the retrieval of index metrics once any column is less than the lowest cardinality threshold +``index_cardinality_cache_polling_duration`` ``10ms`` Sets the cardinality cache polling duration for short circuit retrieval of index metrics +============================================= ============= ======================================================================================================= Adding Columns -------------- diff --git a/presto-docs/src/main/sphinx/connector/cassandra.rst b/presto-docs/src/main/sphinx/connector/cassandra.rst index 13603af662028..9c93f5bbb4328 100644 --- a/presto-docs/src/main/sphinx/connector/cassandra.rst +++ b/presto-docs/src/main/sphinx/connector/cassandra.rst @@ -49,16 +49,6 @@ Property Name Description ``cassandra.native-protocol-port`` The Cassandra server port running the native client protocol (defaults to ``9042``). -``cassandra.max-schema-refresh-threads`` Maximum number of schema cache refresh threads. This property - corresponds to the maximum number of parallel requests. - -``cassandra.schema-cache-ttl`` Maximum time that information about a schema will be cached - (defaults to ``1h``). - -``cassandra.schema-refresh-interval`` The schema information cache will be refreshed in the background - when accessed if the cached data is at least this old - (defaults to ``2m``). - ``cassandra.consistency-level`` Consistency levels in Cassandra refer to the level of consistency to be used for both read and write operations. More information about consistency levels can be found in the @@ -184,3 +174,59 @@ This table can be described in Presto:: This table can then be queried in Presto:: SELECT * FROM cassandra.mykeyspace.users; + +Data types +---------- + +The data types mappings are as follows: + +================ ====== +Cassandra Presto +================ ====== +ASCII VARCHAR +BIGINT BIGINT +BLOB VARBINARY +BOOLEAN BOOLEAN +DECIMAL DOUBLE +DOUBLE DOUBLE +FLOAT DOUBLE +INET VARCHAR(45) +INT INTEGER +LIST VARCHAR +MAP VARCHAR +SET VARCHAR +TEXT VARCHAR +TIMESTAMP TIMESTAMP +TIMEUUID VARCHAR +VARCHAR VARCHAR +VARIANT VARCHAR +================ ====== + +Any collection (LIST/MAP/SET) can be designated as FROZEN, and the value is +mapped to VARCHAR. Additionally, blobs have the limitation that they cannot be empty. + +Types not mentioned in the table above are not supported (e.g. tuple or UDT). + +Partition keys can only be of the following types: +| ASCII +| TEXT +| VARCHAR +| BIGINT +| BOOLEAN +| DOUBLE +| INET +| INT +| FLOAT +| DECIMAL +| TIMESTAMP +| UUID +| TIMEUUID + +Limitations +----------- + +* Queries without filters containing the partition key result in fetching all partitions. + This causes a full scan of the entire data set, therefore it's much slower compared to a similar + query with a partition key as a filter. +* ``IN`` list filters are only allowed on index (that is, partition key or clustering key) columns. +* Range (``<`` or ``>`` and ``BETWEEN``) filters can be applied only to the partition keys. diff --git a/presto-docs/src/main/sphinx/connector/kafka-tutorial.rst b/presto-docs/src/main/sphinx/connector/kafka-tutorial.rst index d4d0e820e1dd7..f9dcea003b8d1 100644 --- a/presto-docs/src/main/sphinx/connector/kafka-tutorial.rst +++ b/presto-docs/src/main/sphinx/connector/kafka-tutorial.rst @@ -192,8 +192,8 @@ actual table shape. The raw data is available through the ``_message`` and in JSON format, the :doc:`/functions/json` built into Presto can be used to slice the data. -Step 5: Add a topic decription file ------------------------------------ +Step 5: Add a topic description file +------------------------------------ The Kafka connector supports topic description files to turn raw data into table format. These files are located in the ``etc/kafka`` folder in the diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index f949886f77e62..0f6b3320818b6 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -101,6 +101,10 @@ Array Functions See :func:`reduce`. +.. function:: repeat(element, count) -> array + + Repeat ``element`` for ``count`` times. + .. function:: reverse(x) -> array :noindex: diff --git a/presto-docs/src/main/sphinx/functions/binary.rst b/presto-docs/src/main/sphinx/functions/binary.rst index cb8026e962a3e..e0cb5897f5145 100644 --- a/presto-docs/src/main/sphinx/functions/binary.rst +++ b/presto-docs/src/main/sphinx/functions/binary.rst @@ -42,6 +42,19 @@ Binary Functions Decodes ``bigint`` value from a 64-bit 2's complement big endian ``binary``. +.. function:: to_ieee754_32(real) -> varbinary + + Encodes ``real`` in a 32-bit big-endian binary according to IEEE 754 single-precision floating-point format. + +.. function:: to_ieee754_64(double) -> varbinary + + Encodes ``double`` in a 64-bit big-endian binary according to IEEE 754 double-precision floating-point format. + +.. function:: crc32(binary) -> bigint + + Computes the CRC-32 of ``binary``. For general purpose hashing, use + :func:`xxhash64`, as it is much faster and produces a better quality hash. + .. function:: md5(binary) -> varbinary Computes the md5 hash of ``binary``. diff --git a/presto-docs/src/main/sphinx/functions/comparison.rst b/presto-docs/src/main/sphinx/functions/comparison.rst index 9c3da56ed3b8a..1b6d08ce9826b 100644 --- a/presto-docs/src/main/sphinx/functions/comparison.rst +++ b/presto-docs/src/main/sphinx/functions/comparison.rst @@ -113,11 +113,11 @@ The following types are supported: ``TIMESTAMP WITH TIME ZONE``, ``DATE`` -.. function:: greatest(value1, value2) -> [same as input] +.. function:: greatest(value1, value2, ..., valueN) -> [same as input] Returns the largest of the provided values. -.. function:: least(value1, value2) -> [same as input] +.. function:: least(value1, value2, ..., valueN) -> [same as input] Returns the smallest of the provided values. @@ -153,9 +153,3 @@ Expression Meaning ==================== =========== ``ANY`` and ``SOME`` have the same meaning and can be used interchangeably. - -.. note:: - - Currently, the expression ``A`` in ``A = ANY (...)`` or ``A <> ALL (...)`` - must not be ``NULL`` for any of the queried rows. Otherwise, the query will fail. - This limitation is needed to ensure correct results and may be dropped in the future. diff --git a/presto-docs/src/main/sphinx/functions/datetime.rst b/presto-docs/src/main/sphinx/functions/datetime.rst index 5c758de953d70..defac4f571cf5 100644 --- a/presto-docs/src/main/sphinx/functions/datetime.rst +++ b/presto-docs/src/main/sphinx/functions/datetime.rst @@ -156,6 +156,32 @@ Unit Description Returns ``timestamp2 - timestamp1`` expressed in terms of ``unit``. +Duration Function +----------------- + +The ``parse_duration`` function supports the following units: + +======= ============= +Unit Description +======= ============= +``ns`` Nanoseconds +``us`` Microseconds +``ms`` Milliseconds +``s`` Seconds +``m`` Minutes +``h`` Hours +``d`` Days +======= ============= + +.. function:: parse_duration(string) -> interval + + Parses ``string`` of format ``value unit`` into an interval, where + ``value`` is fractional number of ``unit`` values:: + + SELECT parse_duration('42.8ms'); -- 0 00:00:00.043 + SELECT parse_duration('3.81 d'); -- 3 19:26:24.000 + SELECT parse_duration('5m'); -- 0 00:05:00.000 + MySQL Date Functions -------------------- @@ -168,11 +194,11 @@ Specifier Description ========= =========== ``%a`` Abbreviated weekday name (``Sun`` .. ``Sat``) ``%b`` Abbreviated month name (``Jan`` .. ``Dec``) -``%c`` Month, numeric (``0`` .. ``12``) +``%c`` Month, numeric (``1`` .. ``12``) [#z]_ ``%D`` Day of the month with English suffix (``0th``, ``1st``, ``2nd``, ``3rd``, ...) -``%d`` Day of the month, numeric (``00`` .. ``31``) -``%e`` Day of the month, numeric (``0`` .. ``31``) -``%f`` Fraction of second (6 digits for printing: ``000000`` .. ``999000``; 1 - 9 digits for parsing: ``0`` .. ``999999999`` [#f]_) +``%d`` Day of the month, numeric (``01`` .. ``31``) [#z]_ +``%e`` Day of the month, numeric (``1`` .. ``31``) [#z]_ +``%f`` Fraction of second (6 digits for printing: ``000000`` .. ``999000``; 1 - 9 digits for parsing: ``0`` .. ``999999999``) [#f]_ ``%H`` Hour (``00`` .. ``23``) ``%h`` Hour (``01`` .. ``12``) ``%I`` Hour (``01`` .. ``12``) @@ -181,7 +207,7 @@ Specifier Description ``%k`` Hour (``0`` .. ``23``) ``%l`` Hour (``1`` .. ``12``) ``%M`` Month name (``January`` .. ``December``) -``%m`` Month, numeric (``00`` .. ``12``) +``%m`` Month, numeric (``01`` .. ``12``) [#z]_ ``%p`` ``AM`` or ``PM`` ``%r`` Time, 12-hour (``hh:mm:ss`` followed by ``AM`` or ``PM``) ``%S`` Seconds (``00`` .. ``59``) @@ -192,7 +218,7 @@ Specifier Description ``%V`` Week (``01`` .. ``53``), where Sunday is the first day of the week; used with ``%X`` ``%v`` Week (``01`` .. ``53``), where Monday is the first day of the week; used with ``%x`` ``%W`` Weekday name (``Sunday`` .. ``Saturday``) -``%w`` Day of the week (``0`` .. ``6``), where Sunday is the first day of the week +``%w`` Day of the week (``0`` .. ``6``), where Sunday is the first day of the week [#w]_ ``%X`` Year for the week where Sunday is the first day of the week, numeric, four digits; used with ``%V`` ``%x`` Year for the week, where Monday is the first day of the week, numeric, four digits; used with ``%v`` ``%Y`` Year, numeric, four digits @@ -203,6 +229,8 @@ Specifier Description .. [#f] Timestamp is truncated to milliseconds. .. [#y] When parsing, two-digit year format assumes range ``1970`` .. ``2069``, so "70" will result in year ``1970`` but "69" will produce ``2069``. +.. [#w] This specifier is not supported yet. Consider using :func:`day_of_week` (it uses ``1-7`` instead of ``0-6``). +.. [#z] This specifier does not support ``0`` as a month or day. .. warning:: The following specifiers are not currently supported: ``%D %U %u %V %w %X`` diff --git a/presto-docs/src/main/sphinx/language.rst b/presto-docs/src/main/sphinx/language.rst index 63f9d14fdeaa4..a19e4b60733b5 100644 --- a/presto-docs/src/main/sphinx/language.rst +++ b/presto-docs/src/main/sphinx/language.rst @@ -6,3 +6,4 @@ SQL Language :maxdepth: 1 language/types + language/reserved diff --git a/presto-docs/src/main/sphinx/language/reserved.rst b/presto-docs/src/main/sphinx/language/reserved.rst new file mode 100644 index 0000000000000..9cf05af04d9e8 --- /dev/null +++ b/presto-docs/src/main/sphinx/language/reserved.rst @@ -0,0 +1,80 @@ +================= +Reserved Keywords +================= + +The following table lists all of the keywords that are reserved in Presto, +along with their status in the SQL standard. These reserved keywords must +be quoted (using double quotes) in order to be used as an identifier. + +============================== ============= ============= +Keyword SQL:2016 SQL-92 +============================== ============= ============= +``ALTER`` reserved reserved +``AND`` reserved reserved +``AS`` reserved reserved +``BETWEEN`` reserved reserved +``BY`` reserved reserved +``CASE`` reserved reserved +``CAST`` reserved reserved +``CONSTRAINT`` reserved reserved +``CREATE`` reserved reserved +``CROSS`` reserved reserved +``CUBE`` reserved +``CURRENT_DATE`` reserved reserved +``CURRENT_TIME`` reserved reserved +``CURRENT_TIMESTAMP`` reserved reserved +``DEALLOCATE`` reserved reserved +``DELETE`` reserved reserved +``DESCRIBE`` reserved reserved +``DISTINCT`` reserved reserved +``DROP`` reserved reserved +``ELSE`` reserved reserved +``END`` reserved reserved +``ESCAPE`` reserved reserved +``EXCEPT`` reserved reserved +``EXECUTE`` reserved reserved +``EXISTS`` reserved reserved +``EXTRACT`` reserved reserved +``FALSE`` reserved reserved +``FOR`` reserved reserved +``FROM`` reserved reserved +``FULL`` reserved reserved +``GROUP`` reserved reserved +``GROUPING`` reserved +``HAVING`` reserved reserved +``IN`` reserved reserved +``INNER`` reserved reserved +``INSERT`` reserved reserved +``INTERSECT`` reserved reserved +``INTO`` reserved reserved +``IS`` reserved reserved +``JOIN`` reserved reserved +``LEFT`` reserved reserved +``LIKE`` reserved reserved +``LOCALTIME`` reserved +``LOCALTIMESTAMP`` reserved +``NATURAL`` reserved reserved +``NORMALIZE`` reserved +``NOT`` reserved reserved +``NULL`` reserved reserved +``ON`` reserved reserved +``OR`` reserved reserved +``ORDER`` reserved reserved +``OUTER`` reserved reserved +``PREPARE`` reserved reserved +``RECURSIVE`` reserved +``RIGHT`` reserved reserved +``ROLLUP`` reserved +``SELECT`` reserved reserved +``TABLE`` reserved reserved +``THEN`` reserved reserved +``TRUE`` reserved reserved +``UESCAPE`` reserved +``UNION`` reserved reserved +``UNNEST`` reserved +``USING`` reserved reserved +``VALUES`` reserved reserved +``WHEN`` reserved reserved +``WHERE`` reserved reserved +``WITH`` reserved reserved +============================== ============= ============= diff --git a/presto-docs/src/main/sphinx/language/types.rst b/presto-docs/src/main/sphinx/language/types.rst index b75fa34fe58ad..05a1bd8dac799 100644 --- a/presto-docs/src/main/sphinx/language/types.rst +++ b/presto-docs/src/main/sphinx/language/types.rst @@ -31,7 +31,8 @@ INTEGER ------- A 32-bit signed two's complement integer with a minimum value of - ``-2^31`` and a maximum value of ``2^31 - 1``. + ``-2^31`` and a maximum value of ``2^31 - 1``. The name INT is + also available for this type. BIGINT ------ diff --git a/presto-docs/src/main/sphinx/release.rst b/presto-docs/src/main/sphinx/release.rst index 94f0edb6e2631..2bde0e3eeab31 100644 --- a/presto-docs/src/main/sphinx/release.rst +++ b/presto-docs/src/main/sphinx/release.rst @@ -5,6 +5,11 @@ Release Notes .. toctree:: :maxdepth: 1 + release/release-0.179 + release/release-0.178 + release/release-0.177 + release/release-0.176 + release/release-0.175 release/release-0.174 release/release-0.173 release/release-0.172 diff --git a/presto-docs/src/main/sphinx/release/release-0.156.rst b/presto-docs/src/main/sphinx/release/release-0.156.rst index a00c1325d9bb5..5e78305c50a10 100644 --- a/presto-docs/src/main/sphinx/release/release-0.156.rst +++ b/presto-docs/src/main/sphinx/release/release-0.156.rst @@ -2,6 +2,12 @@ Release 0.156 ============= +.. warning:: + + Query may incorrectly produce ``NULL`` when no row qualifies for the aggregation + if the ``optimize_mixed_distinct_aggregations`` session property or + the ``optimizer.optimize-mixed-distinct-aggregations`` config option is enabled. + General Changes --------------- diff --git a/presto-docs/src/main/sphinx/release/release-0.175.rst b/presto-docs/src/main/sphinx/release/release-0.175.rst new file mode 100644 index 0000000000000..4e249c9840718 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.175.rst @@ -0,0 +1,36 @@ +============= +Release 0.175 +============= + +General Changes +--------------- + +* Fix *"position is not valid"* query execution failures. +* Fix memory accounting bug that can potentially cause ``OutOfMemoryError``. +* Fix regression that could cause certain queries involving ``UNION`` and + ``GROUP BY`` or ``JOIN`` to fail during planning. +* Fix planning failure for ``GROUP BY`` queries containing correlated + subqueries in the ``SELECT`` clause. +* Fix execution failure for certain ``DELETE`` queries. +* Reduce occurrences of *"Method code too large"* errors. +* Reduce memory utilization for certain queries involving ``ORDER BY``. +* Improve performance of map subscript from O(n) to O(1) when the map is + produced by an eligible operation, including the map constructor and + Hive readers (except ORC and optimized Parquet). More read and write + operations will take advantage of this in future releases. +* Add ``enable_intermediate_aggregations`` session property to enable the + use of intermediate aggregations within un-grouped aggregations. +* Add support for ``INTERVAL`` data type to :func:`avg` and :func:`sum` aggregation functions. +* Add support for ``INT`` as an alias for the ``INTEGER`` data type. +* Add resource group information to query events. + +Hive Changes +------------ + +* Make table creation metastore operations idempotent, which allows + recovery when retrying timeouts or other errors. + +MongoDB Changes +--------------- + +* Rename ``mongodb.connection-per-host`` config option to ``mongodb.connections-per-host``. diff --git a/presto-docs/src/main/sphinx/release/release-0.176.rst b/presto-docs/src/main/sphinx/release/release-0.176.rst new file mode 100644 index 0000000000000..369ca619bb255 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.176.rst @@ -0,0 +1,29 @@ +============= +Release 0.176 +============= + +General Changes +--------------- + +* Fix an issue where a query (and some of its tasks) continues to + consume CPU/memory on the coordinator and workers after the query fails. +* Fix a regression that cause the GC overhead and pauses to increase significantly when processing maps. +* Fix a memory tracking bug that causes the memory to be overestimated for ``GROUP BY`` queries on ``bigint`` columns. +* Improve the performance of the :func:`transform_values` function. +* Add support for casting from ``JSON`` to ``REAL`` type. +* Add :func:`parse_duration` function. + +MySQL Changes +------------- + +* Disallow having a database in the ``connection-url`` config property. + +Accumulo Changes +---------------- + +* Decrease planning time by fetching index metrics in parallel. + +MongoDB Changes +--------------- + +* Allow predicate pushdown for ObjectID. diff --git a/presto-docs/src/main/sphinx/release/release-0.177.rst b/presto-docs/src/main/sphinx/release/release-0.177.rst new file mode 100644 index 0000000000000..0f5081bb313f9 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.177.rst @@ -0,0 +1,77 @@ +============= +Release 0.177 +============= + +.. warning:: + + Query may incorrectly produce ``NULL`` when no row qualifies for the aggregation + if the ``optimize_mixed_distinct_aggregations`` session property or + the ``optimizer.optimize-mixed-distinct-aggregations`` config option is enabled. + This optimization was introduced in Presto version 0.156. + +General Changes +--------------- + +* Fix correctness issue when performing range comparisons over columns of type ``CHAR``. +* Fix correctness issue due to mishandling of nulls and non-deterministic expressions in + inequality joins unless ``fast_inequality_join`` is disabled. +* Fix excessive GC overhead caused by lambda expressions. There are still known GC issues + with captured lambda expressions. This will be fixed in a future release. +* Check for duplicate columns in ``CREATE TABLE`` before asking the connector to create + the table. This improves the error message for most connectors and will prevent errors + for connectors that do not perform validation internally. +* Add support for null values on the left-hand side of a semijoin (i.e., ``IN`` predicate + with subqueries). +* Add ``SHOW STATS`` to display table and query statistics. +* Improve implicit coercion support for functions involving lambda. Specifically, this makes + it easier to use the :func:`reduce` function. +* Improve plans for queries involving ``ORDER BY`` and ``LIMIT`` by avoiding unnecessary + data exchanges. +* Improve performance of queries containing window functions with identical ``PARTITION BY`` + and ``ORDER BY`` clauses. +* Improve performance of certain queries involving ``OUTER JOIN`` and aggregations, or + containing certain forms of correlated subqueries. This optimization is experimental + and can be turned on via the ``push_aggregation_through_join`` session property or the + ``optimizer.push-aggregation-through-join`` config option. +* Improve performance of certain queries involving joins and aggregations. This optimization + is experimental and can be turned on via the ``push_partial_aggregation_through_join`` + session property. +* Improve error message when a lambda expression has a different number of arguments than expected. +* Improve error message when certain invalid ``GROUP BY`` expressions containing lambda expressions. + +Hive Changes +------------ + +* Fix handling of trailing spaces for the ``CHAR`` type when reading RCFile. +* Allow inserts into tables that have more partitions than the partitions-per-scan limit. +* Add support for exposing Hive table statistics to the engine. This option is experimental and + can be turned on via the ``statistics_enabled`` session property. +* Ensure file name is always present for error messages about corrupt ORC files. + +Cassandra Changes +----------------- + +* Remove caching of metadata in the Cassandra connector. Metadata caching makes Presto violate + the consistency defined by the Cassandra cluster. It's also unnecessary because the Cassandra + driver internally caches metadata. The ``cassandra.max-schema-refresh-threads``, + ``cassandra.schema-cache-ttl`` and ``cassandra.schema-refresh-interval`` config options have + been removed. +* Fix intermittent issue in the connection retry mechanism. + +Web UI Changes +-------------- + +* Change cluster HUD realtime statistics to be aggregated across all running queries. +* Change parallelism statistic on cluster HUD to be averaged per-worker. +* Fix bug that always showed indeterminate progress bar in query list view. +* Change running drivers statistic to exclude blocked drivers. +* Change unit of CPU and scheduled time rate sparklines to seconds on query details page. +* Change query details page refresh interval to three seconds. +* Add uptime and connected status indicators to every page. + +CLI Changes +----------- + +* Add support for preprocessing commands. When the ``PRESTO_PREPROCESSOR`` environment + variable is set, all commands are piped through the specified program before being sent to + the Presto server. diff --git a/presto-docs/src/main/sphinx/release/release-0.178.rst b/presto-docs/src/main/sphinx/release/release-0.178.rst new file mode 100644 index 0000000000000..2912a9775fcab --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.178.rst @@ -0,0 +1,31 @@ +============= +Release 0.178 +============= + +General Changes +--------------- + +* Fix various memory accounting bugs, which reduces the likelihood of full GCs/OOMs. +* Fix a regression that causes queries that use the keyword "stats" to fail to parse. +* Fix an issue where a query does not get cleaned up on the coordinator after query failure. +* Add ability to cast to ``JSON`` from ``REAL``, ``TINYINT`` or ``SMALLINT``. +* Add support for ``GROUPING`` operation to :ref:`complex grouping operations`. +* Add support for correlated subqueries in ``IN`` predicates. +* Add :func:`to_ieee754_32` and :func:`to_ieee754_64` functions. + +Hive Changes +------------ + +* Fix high CPU usage due to schema caching when reading Avro files. +* Preserve decompression error causes when decoding ORC files. + +Memory Connector Changes +------------------------ + +* Fix a bug that prevented creating empty tables. + +SPI Changes +----------- + +* Make environment available to resource group configuration managers. +* Add additional performance statistics to query completion event. diff --git a/presto-docs/src/main/sphinx/release/release-0.179.rst b/presto-docs/src/main/sphinx/release/release-0.179.rst new file mode 100644 index 0000000000000..727faa9455637 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.179.rst @@ -0,0 +1,40 @@ +============= +Release 0.179 +============= + +General Changes +--------------- + +* Fix issue which could cause incorrect results when processing dictionary encoded data. If the expression + can fail on bad input, the results from filtered-out rows containing bad input may be included in the query + output. See `#8262 `_ for more details. +* Fix planning failure when similar expressions appear in the ``ORDER BY`` clause of a query that + contains ``ORDER BY`` and ``LIMIT``. +* Fix planning failure when ``GROUPING()`` is used with the ``legacy_order_by`` session property set to ``true``. +* Fix parsing failure when ``NFD``, ``NFC``, ``NFKD`` or ``NFKC`` are used as identifiers. +* Fix a memory leak on the coordinator that manifests itself with canceled queries. +* Fix excessive GC overhead caused by captured lambda expressions. +* Reduce the memory usage of map/array aggregation functions. +* Redact sensitive config property values in the server log. +* Update timezone database to version 2017b. +* Add :func:`repeat` function. +* Add :func:`crc32` function. +* Add file based global security, which can be configured with the ``etc/access-control.properties`` + and ``security.config-file`` config properties. See :doc:`/security/built-in-system-access-control` + for more details. +* Add support for configuring query runtime and queueing time limits to resource groups. + +Hive Changes +------------ + +* Fail queries that access encrypted S3 objects that do not have their unencrypted content lengths set in their metadata. + +JDBC Driver Changes +------------------- + +* Add support for setting query timeout through ``Statement.setQueryTimeout()``. + +SPI Changes +----------- + +* Add grantee and revokee to ``GRANT`` and ``REVOKE`` security checks. diff --git a/presto-docs/src/main/sphinx/rest.rst b/presto-docs/src/main/sphinx/rest.rst index d65b9927a7595..3deb9709d31b8 100644 --- a/presto-docs/src/main/sphinx/rest.rst +++ b/presto-docs/src/main/sphinx/rest.rst @@ -13,7 +13,6 @@ responses. .. toctree:: :maxdepth: 1 - rest/execute rest/node rest/query rest/stage @@ -30,12 +29,6 @@ The Presto REST API contains several, high-level resources that correspond to the components of a Presto installation. -Execute Resource - - The execute resource is what the client sends queries to. It is - available at the path ``/v1/execute``, and accepts a query as a POST - and returns JSON. - Query Resource The query resource takes a SQL query. It is available at the path diff --git a/presto-docs/src/main/sphinx/rest/execute.rst b/presto-docs/src/main/sphinx/rest/execute.rst deleted file mode 100644 index 098d4de1d2972..0000000000000 --- a/presto-docs/src/main/sphinx/rest/execute.rst +++ /dev/null @@ -1,63 +0,0 @@ -================ -Execute Resource -================ - -.. function:: POST /v1/execute - - :Body: SQL Query to execute - :Header "X-Presto-User": User to execute statement on behalf of (optional) - :Header "X-Presto-Source": Source of query - :Header "X-Presto-Catalog": Catalog to execute query against - :Header "X-Presto-Schema": Schema to execute query against - - Call this to execute a SQL statement as an alternative to running - ``/v1/statement``. Where ``/v1/statement`` will return a - ``nextUri`` and details about a running query, the ``/v1/execute`` - call will simply execute the SQL statement posted to it and return - the result set. This service will not return updates about query - status or details about stages and tasks. It simply executes a - query and returns the result. - - The sample request and response shown below demonstrate how the - execute call works. Once you post a SQL statement to /v1/execute it - returns a set of columns describing an array of data items. This - trivial executes a "show functions" statement. - - **Example request**: - - .. sourcecode:: http - - POST /v1/execute HTTP/1.1 - Host: localhost:8001 - X-Presto-Schema: jmx - X-Presto-User: tobrie1 - X-Presto-Catalog: jmx - Content-Type: text/html - Content-Length: 14 - - show functions - - **Example response**: - - .. sourcecode:: http - - HTTP/1.1 200 OK - Content-Type: application/json - X-Content-Type-Options: nosniff - Transfer-Encoding: chunked - - {"columns": - [ - {"name":"Function","type":"varchar"}, - {"name":"Return Type","type":"varchar"}, - {"name":"Argument Types","type":"varchar"}, - {"name":"Function Type","type":"varchar"}, - {"name":"Description","type":"varchar"} - ], - "data": - [ - ["abs","bigint","bigint","scalar","absolute value"], - ["abs","double","double","scalar","absolute value"], - ... - ] - }; diff --git a/presto-docs/src/main/sphinx/security.rst b/presto-docs/src/main/sphinx/security.rst index 921bd18f09542..27b08492d4d06 100644 --- a/presto-docs/src/main/sphinx/security.rst +++ b/presto-docs/src/main/sphinx/security.rst @@ -9,3 +9,5 @@ Security security/cli security/ldap security/tls + security/built-in-system-access-control + security/internal-communication diff --git a/presto-docs/src/main/sphinx/security/built-in-system-access-control.rst b/presto-docs/src/main/sphinx/security/built-in-system-access-control.rst new file mode 100644 index 0000000000000..5012d8fc5cde1 --- /dev/null +++ b/presto-docs/src/main/sphinx/security/built-in-system-access-control.rst @@ -0,0 +1,111 @@ +============================== +Built-in System Access Control +============================== + +A system access control plugin enforces authorization at a global level, +before any connector level authorization. You can either use one of the built-in +plugins in Presto or provide your own by following the guidelines in +:doc:`/develop/system-access-control`. Presto offers three built-in plugins: + +================================================== ============================================================ +Plugin Name Description +================================================== ============================================================ +``allow-all`` (default value) All operations are permitted. + +``read-only`` Operations that read data or metadata are permitted, but + none of the operations that write data or metadata are + allowed. See :ref:`read-only-system-access-control` for + details. + +``file`` Authorization checks are enforced using a config file + specified by the configuration property ``security.config-file``. + See :ref:`file-based-system-access-control` for details. +================================================== ============================================================ + +Allow All System Access Control +=============================== + +All operations are permitted under this plugin. This plugin is enabled by default. + +.. _read-only-system-access-control: + +Read Only System Access Control +=============================== + +Under this plugin, you are allowed to execute any operation that reads data or +metadata, such as ``SELECT`` or ``SHOW``. Setting system level or catalog level +session properties is also permitted. However, any operation that writes data or +metadata, such as ``CREATE``, ``INSERT`` or ``DELETE``, is prohibited. +To use this plugin, add an ``etc/access-control.properties`` +file with the following contents: + +.. code-block:: none + + access-control.name=read-only + +.. _file-based-system-access-control: + +File Based System Access Control +================================ + +This plugin allows you to specify access control rules in a file. To use this +plugin, add an ``etc/access-control.properties`` file containing two required +properties: ``access-control.name``, which must be equal to ``file``, and +``security.config-file``, which must be equal to the location of the config file. +For example, if a config file named ``rules.json`` +resides in ``etc``, add an ``etc/access-control.properties`` with the following +contents: + +.. code-block:: none + + access-control.name=file + security.config-file=etc/rules.json + +The config file consists of a list of access control rules in JSON format. The +rules are matched in the order specified in the file. All +regular expressions default to ``.*`` if not specified. + +This plugin currently only supports catalog access control rules. If you want +to limit access on a system level in any other way, you must implement a custom +SystemAccessControl plugin (see :doc:`/develop/system-access-control`). + +Catalog Rules +------------- + +These rules govern the catalogs particular users can access. The user is +granted access to a catalog based on the first matching rule. If no rule +matches, access is denied. Each rule is composed of the following fields: + +* ``user`` (optional): regex to match against user name. +* ``catalog`` (optional): regex to match against catalog name. +* ``allowed`` (required): boolean indicating whether a user has access to the catalog + +.. note:: + + By default, all users have access to the ``system`` catalog. You can + override this behavior by adding a rule. + +For example, if you want to allow only the user ``admin`` to access the +``mysql`` and the ``system`` catalog, allow all users to access the ``hive`` +catalog, and deny all other access, you can use the following rules: + +.. code-block:: json + + { + "catalogs": [ + { + "user": "admin", + "catalog": "(mysql|system)", + "allow": true + }, + { + "catalog": "hive", + "allow": true + }, + { + "catalog": "system", + "allow": false + } + ] + } + diff --git a/presto-docs/src/main/sphinx/security/internal-communication.rst b/presto-docs/src/main/sphinx/security/internal-communication.rst new file mode 100644 index 0000000000000..4a1092a8fc84d --- /dev/null +++ b/presto-docs/src/main/sphinx/security/internal-communication.rst @@ -0,0 +1,156 @@ +============================= +Secure Internal Communication +============================= + +The Presto cluster can be configured to use secured communication. Communication +between Presto nodes can be secured with SSL/TLS. + +Internal SSL/TLS configuration +------------------------------ + +SSL/TLS is configured in the `config.properties` file. The SSL/TLS on the +worker and coordinator nodes are configured using the same set of properties. +Every node in the cluster must be configured. Nodes that have not been +configured, or are configured incorrectly, will not be able to communicate with +other nodes in the cluster. + +To enable SSL/TLS for Presto internal communication, do the following: + +1. Disable HTTP endpoint. + + .. code-block:: none + + http-server.http.enabled=false + + .. warning:: + + You can enable HTTPS while leaving HTTP enabled. In most cases this is a + security hole. If you are certain you want to use this configuration, you + should consider using an firewall to limit access to the HTTP endpoint to + only those hosts that should be allowed to use it. + +2. Configure the cluster to communicate using the fully qualified domain name (fqdn) + of the cluster nodes. This can be done in either of the following ways: + + - If the DNS service is configured properly, we can just let the nodes to + introduce themselves to the coordinator using the hostname taken from + the system configuration (`hostname --fqdn`) + + .. code-block:: none + + node.internal-address-source=FQDN + + - It is also possible to specify each node's fully-qualified hostname manually. + This will be different for every host. Hosts should be in the same domain to + make it easy to create the correct SSL/TLS certificates. + e.g.: `coordinator.example.com`, `worker1.example.com`, `worker2.example.com`. + + .. code-block:: none + + node.internal-address= + + +3. Generate a Java Keystore File. Every Presto node must be able to connect to + any other node within the same cluster. It is possible to create unique + certificates for every node using the fully-qualified hostname of each host, + create a keystore that contains all the public keys for all of the hosts, + and specify it for the client (`http-client.https.keystore.path`). In most + cases it will be simpler to use a wildcard in the certificate as shown + below. + + .. code-block:: none + + keytool -genkeypair -alias example.com -keyalg RSA -keystore keystore.jks + Enter keystore password: + Re-enter new password: + What is your first and last name? + [Unknown]: *.example.com + What is the name of your organizational unit? + [Unknown]: + What is the name of your organization? + [Unknown]: + What is the name of your City or Locality? + [Unknown]: + What is the name of your State or Province? + [Unknown]: + What is the two-letter country code for this unit? + [Unknown]: + Is CN=*.example.com, OU=Unknown, O=Unknown, L=Unknown, ST=Unknown, C=Unknown correct? + [no]: yes + + Enter key password for + (RETURN if same as keystore password): + + .. Note: Replace `example.com` with the appropriate domain. + +4. Distribute the Java Keystore File across the Presto cluster. + +5. Enable the HTTPS endpoint. + + .. code-block:: none + + http-server.https.enabled=true + http-server.https.port= + http-server.https.keystore.path= + http-server.https.keystore.key= + +6. Change the discovery uri to HTTPS. + + .. code-block:: none + + discovery.uri=https://: + +7. Configure the internal communication to require HTTPS. + + .. code-block:: none + + internal-communication.https.required=true + +8. Configure the internal communication to use the Java keystore file. + + .. code-block:: none + + internal-communication.https.keystore.path= + internal-communication.https.keystore.key= + + +Performance with SSL/TLS enabled +-------------------------------- + +Enabling encryption impacts performance. The performance degradation can vary +based on the environment, queries, and concurrency. + +For queries that do not require transferring too much data between the Presto +nodes (e.g. `SELECT count(*) FROM table`), the performance impact is negligible. + +However, for CPU intensive queries which require a considerable amount of data +to be transferred between the nodes (for example, distributed joins, aggregations and +window functions, which require repartitioning), the performance impact might be +considerable. The slowdown may vary from 10% to even 100%+, depending on the network +traffic and the CPU utilization. + +Advanced Performance Tuning +--------------------------- + +In some cases, changing the source of random numbers will improve performance +significantly. + +By default, TLS encryption uses the `/dev/urandom` system device as a source of entropy. +This device has limited throughput, so on environments with high network bandwidth +(e.g. InfiniBand), it may become a bottleneck. In such situations, it is recommended to try +to switch the random number generator algorithm to `SHA1PRNG`, by setting it via +`http-server.https.secure-random-algorithm` property in `config.properties` on the coordinator +and all of the workers: + + .. code-block:: none + + http-server.https.secure-random-algorithm=SHA1PRNG + +Be aware that this algorithm takes the initial seed from +the blocking `/dev/random` device. For environments that do not have enough entropy to seed +the `SHAPRNG` algorithm, the source can be changed to `/dev/urandom` +by adding the `java.security.egd` property to `jvm.config`: + + .. code-block:: none + + -Djava.security.egd=file:/dev/urandom diff --git a/presto-docs/src/main/sphinx/sql/select.rst b/presto-docs/src/main/sphinx/sql/select.rst index 07a9bb3682b1d..441476f042484 100644 --- a/presto-docs/src/main/sphinx/sql/select.rst +++ b/presto-docs/src/main/sphinx/sql/select.rst @@ -364,6 +364,48 @@ only unique grouping sets are generated:: The default set quantifier is ``ALL``. +**GROUPING Operation** + +``grouping(col1, ..., colN) -> bigint`` + +The grouping operation returns a bit set converted to decimal, indicating which columns are present in a +grouping. It must be used in conjunction with ``GROUPING SETS``, ``ROLLUP``, ``CUBE`` or ``GROUP BY`` +and its arguments must match exactly the columns referenced in the corresponding ``GROUPING SETS``, +``ROLLUP``, ``CUBE`` or ``GROUP BY`` clause. + +To compute the resulting bit set for a particular row, bits are assigned to the argument columns with +the rightmost column being the least significant bit. For a given grouping, a bit is set to 0 if the +corresponding column is included in the grouping and to 1 otherwise. For example, consider the query +below:: + + SELECT origin_state, origin_zip, destination_state, sum(package_weight), + grouping(origin_state, origin_zip, destination_state) + FROM shipping + GROUP BY GROUPING SETS ( + (origin_state), + (origin_state, origin_zip), + (destination_state)); + +.. code-block:: none + + origin_state | origin_zip | destination_state | _col3 | _col4 + --------------+------------+-------------------+-------+------- + California | NULL | NULL | 1397 | 3 + New Jersey | NULL | NULL | 225 | 3 + New York | NULL | NULL | 3 | 3 + California | 94131 | NULL | 60 | 1 + New Jersey | 7081 | NULL | 225 | 1 + California | 90210 | NULL | 1337 | 1 + New York | 10002 | NULL | 3 | 1 + NULL | NULL | New Jersey | 58 | 6 + NULL | NULL | Connecticut | 1562 | 6 + NULL | NULL | Colorado | 5 | 6 + (10 rows) + +The first grouping in the above result only includes the ``origin_state`` column and excludes +the ``origin_zip`` and ``destination_state`` columns. The bit set constructed for that grouping +is ``011`` where the most significant bit represents ``origin_state``. + HAVING Clause ------------- @@ -763,12 +805,6 @@ standard rules for nulls. The subquery must produce exactly one column:: FROM nation WHERE regionkey IN (SELECT regionkey FROM region) -.. note:: - - Currently, the expression on the left hand side of ``IN`` must - not be ``NULL`` for any of the queried rows. Otherwise, the query will fail. - This limitation is needed to ensure correct results and may be dropped in the future. - Scalar Subquery ^^^^^^^^^^^^^^^ diff --git a/presto-example-http/pom.xml b/presto-example-http/pom.xml index be5d3b86e5b46..0ae0b5e9b3c31 100644 --- a/presto-example-http/pom.xml +++ b/presto-example-http/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-example-http diff --git a/presto-hive-cdh4/pom.xml b/presto-hive-cdh4/pom.xml deleted file mode 100644 index 3727b81968124..0000000000000 --- a/presto-hive-cdh4/pom.xml +++ /dev/null @@ -1,122 +0,0 @@ - - - 4.0.0 - - - com.facebook.presto - presto-root - 0.175-SNAPSHOT - - - presto-hive-cdh4 - Presto - Hive Connector - CDH 4 - presto-plugin - - - ${project.parent.basedir} - - - - - com.facebook.presto - presto-hive - - - - com.facebook.presto.hadoop - hadoop-cdh4 - runtime - - - - - com.facebook.presto - presto-spi - provided - - - - io.airlift - slice - provided - - - - io.airlift - units - provided - - - - com.fasterxml.jackson.core - jackson-annotations - provided - - - - org.openjdk.jol - jol-core - provided - - - - - org.testng - testng - test - - - - io.airlift - testing - test - - - - com.facebook.presto - presto-hive - test-jar - test - - - - com.facebook.presto - presto-main - test - - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - hive,hive-s3 - - - - - - - - test-hive-cdh4 - - - - org.apache.maven.plugins - maven-surefire-plugin - - hive-s3 - - localhost - 9083 - default - - - - - - - - diff --git a/presto-hive-cdh4/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java b/presto-hive-cdh4/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java deleted file mode 100644 index d983920a1ad03..0000000000000 --- a/presto-hive-cdh4/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; - -@Test(groups = "hive-s3") -public class TestHiveClientS3 - extends AbstractTestHiveClientS3 -{ - @Parameters({ - "hive.cdh4.metastoreHost", - "hive.cdh4.metastorePort", - "hive.cdh4.databaseName", - "hive.cdh4.s3.awsAccessKey", - "hive.cdh4.s3.awsSecretKey", - "hive.cdh4.s3.writableBucket", - }) - @BeforeClass - @Override - public void setup(String host, int port, String databaseName, String awsAccessKey, String awsSecretKey, String writableBucket) - { - super.setup(host, port, databaseName, awsAccessKey, awsSecretKey, writableBucket); - } -} diff --git a/presto-hive-cdh5/pom.xml b/presto-hive-cdh5/pom.xml deleted file mode 100644 index d3f28efd2d987..0000000000000 --- a/presto-hive-cdh5/pom.xml +++ /dev/null @@ -1,122 +0,0 @@ - - - 4.0.0 - - - com.facebook.presto - presto-root - 0.175-SNAPSHOT - - - presto-hive-cdh5 - Presto - Hive Connector - CDH 5 - presto-plugin - - - ${project.parent.basedir} - - - - - com.facebook.presto - presto-hive - - - - com.facebook.presto.hadoop - hadoop-apache2 - runtime - - - - - com.facebook.presto - presto-spi - provided - - - - io.airlift - slice - provided - - - - io.airlift - units - provided - - - - com.fasterxml.jackson.core - jackson-annotations - provided - - - - org.openjdk.jol - jol-core - provided - - - - - org.testng - testng - test - - - - io.airlift - testing - test - - - - com.facebook.presto - presto-hive - test-jar - test - - - - com.facebook.presto - presto-main - test - - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - hive,hive-s3 - - - - - - - - test-hive-cdh5 - - - - org.apache.maven.plugins - maven-surefire-plugin - - hive-s3 - - localhost - 9083 - default - - - - - - - - diff --git a/presto-hive-cdh5/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java b/presto-hive-cdh5/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java deleted file mode 100644 index bbc906b4d6153..0000000000000 --- a/presto-hive-cdh5/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; - -@Test(groups = "hive-s3") -public class TestHiveClientS3 - extends AbstractTestHiveClientS3 -{ - @Parameters({ - "hive.cdh5.metastoreHost", - "hive.cdh5.metastorePort", - "hive.cdh5.databaseName", - "hive.cdh5.s3.awsAccessKey", - "hive.cdh5.s3.awsSecretKey", - "hive.cdh5.s3.writableBucket", - }) - @BeforeClass - @Override - public void setup(String host, int port, String databaseName, String awsAccessKey, String awsSecretKey, String writableBucket) - { - super.setup(host, port, databaseName, awsAccessKey, awsSecretKey, writableBucket); - } -} diff --git a/presto-hive-hadoop1/pom.xml b/presto-hive-hadoop1/pom.xml deleted file mode 100644 index 1143e04717a39..0000000000000 --- a/presto-hive-hadoop1/pom.xml +++ /dev/null @@ -1,122 +0,0 @@ - - - 4.0.0 - - - com.facebook.presto - presto-root - 0.175-SNAPSHOT - - - presto-hive-hadoop1 - Presto - Hive Connector - Apache Hadoop 1.x - presto-plugin - - - ${project.parent.basedir} - - - - - com.facebook.presto - presto-hive - - - - com.facebook.presto.hadoop - hadoop-apache1 - runtime - - - - - com.facebook.presto - presto-spi - provided - - - - io.airlift - slice - provided - - - - io.airlift - units - provided - - - - com.fasterxml.jackson.core - jackson-annotations - provided - - - - org.openjdk.jol - jol-core - provided - - - - - org.testng - testng - test - - - - io.airlift - testing - test - - - - com.facebook.presto - presto-hive - test-jar - test - - - - com.facebook.presto - presto-main - test - - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - hive,hive-s3 - - - - - - - - test-hive-hadoop1 - - - - org.apache.maven.plugins - maven-surefire-plugin - - hive-s3 - - localhost - 9083 - default - - - - - - - - diff --git a/presto-hive-hadoop1/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java b/presto-hive-hadoop1/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java deleted file mode 100644 index cbe5778a596fe..0000000000000 --- a/presto-hive-hadoop1/src/test/java/com/facebook/presto/hive/TestHiveClientS3.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; - -@Test(groups = "hive-s3") -public class TestHiveClientS3 - extends AbstractTestHiveClientS3 -{ - @Parameters({ - "hive.hadoop1.metastoreHost", - "hive.hadoop1.metastorePort", - "hive.hadoop1.databaseName", - "hive.hadoop1.s3.awsAccessKey", - "hive.hadoop1.s3.awsSecretKey", - "hive.hadoop1.s3.writableBucket", - }) - @BeforeClass - @Override - public void setup(String host, int port, String databaseName, String awsAccessKey, String awsSecretKey, String writableBucket) - { - super.setup(host, port, databaseName, awsAccessKey, awsSecretKey, writableBucket); - } -} diff --git a/presto-hive-hadoop2/pom.xml b/presto-hive-hadoop2/pom.xml index adbabf068fe06..9db506f5d5419 100644 --- a/presto-hive-hadoop2/pom.xml +++ b/presto-hive-hadoop2/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-hive-hadoop2 diff --git a/presto-hive/pom.xml b/presto-hive/pom.xml index 49b8e3de1e8ba..322bdb9476f81 100644 --- a/presto-hive/pom.xml +++ b/presto-hive/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-hive @@ -39,7 +39,7 @@ com.facebook.presto.hadoop - hadoop-cdh4 + hadoop-apache2 provided diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java b/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java index f0c6b3d50c68a..0091e26c05ba6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/BackgroundHiveSplitLoader.java @@ -77,6 +77,7 @@ import static com.facebook.presto.hive.HiveUtil.getInputFormat; import static com.facebook.presto.hive.HiveUtil.isSplittable; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkState; import static java.lang.Math.toIntExact; @@ -301,7 +302,7 @@ private void loadPartition(HivePartitionMetadata partition) TextInputFormat targetInputFormat = new TextInputFormat(); // get the configuration for the target path -- it may be a different hdfs instance Configuration targetConfiguration = hdfsEnvironment.getConfiguration(targetPath); - JobConf targetJob = new JobConf(targetConfiguration); + JobConf targetJob = toJobConf(targetConfiguration); targetJob.setInputFormat(TextInputFormat.class); targetInputFormat.configure(targetJob); FileInputFormat.setInputPaths(targetJob, targetPath); @@ -317,7 +318,7 @@ private void loadPartition(HivePartitionMetadata partition) // To support custom input formats, we want to call getSplits() // on the input format to obtain file splits. if (shouldUseFileSplitsFromInputFormat(inputFormat)) { - JobConf jobConf = new JobConf(configuration); + JobConf jobConf = toJobConf(configuration); FileInputFormat.setInputPaths(jobConf, path); InputSplit[] splits = inputFormat.getSplits(jobConf, 0); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarBinaryHiveRecordCursor.java b/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarBinaryHiveRecordCursor.java deleted file mode 100644 index 5ec4a67985abb..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarBinaryHiveRecordCursor.java +++ /dev/null @@ -1,755 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.RecordCursor; -import com.facebook.presto.spi.type.DecimalType; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeManager; -import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableSet; -import io.airlift.slice.ByteArrays; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; -import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable; -import org.apache.hadoop.hive.serde2.columnar.BytesRefWritable; -import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; -import org.apache.hadoop.hive.serde2.io.TimestampWritable; -import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryFactory; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryObject; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; -import org.apache.hadoop.io.WritableUtils; -import org.apache.hadoop.mapred.RecordReader; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; -import java.util.Properties; -import java.util.Set; - -import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.REGULAR; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_CURSOR_ERROR; -import static com.facebook.presto.hive.HiveType.HIVE_BYTE; -import static com.facebook.presto.hive.HiveType.HIVE_DATE; -import static com.facebook.presto.hive.HiveType.HIVE_DOUBLE; -import static com.facebook.presto.hive.HiveType.HIVE_FLOAT; -import static com.facebook.presto.hive.HiveType.HIVE_INT; -import static com.facebook.presto.hive.HiveType.HIVE_LONG; -import static com.facebook.presto.hive.HiveType.HIVE_SHORT; -import static com.facebook.presto.hive.HiveType.HIVE_TIMESTAMP; -import static com.facebook.presto.hive.HiveUtil.closeWithSuppression; -import static com.facebook.presto.hive.HiveUtil.getTableObjectInspector; -import static com.facebook.presto.hive.HiveUtil.isStructuralType; -import static com.facebook.presto.hive.util.DecimalUtils.getLongDecimalValue; -import static com.facebook.presto.hive.util.DecimalUtils.getShortDecimalValue; -import static com.facebook.presto.hive.util.SerDeUtils.getBlockObject; -import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; -import static com.facebook.presto.spi.type.Chars.isCharType; -import static com.facebook.presto.spi.type.Chars.trimSpacesAndTruncateToLength; -import static com.facebook.presto.spi.type.DateType.DATE; -import static com.facebook.presto.spi.type.Decimals.isLongDecimal; -import static com.facebook.presto.spi.type.Decimals.isShortDecimal; -import static com.facebook.presto.spi.type.DoubleType.DOUBLE; -import static com.facebook.presto.spi.type.IntegerType.INTEGER; -import static com.facebook.presto.spi.type.RealType.REAL; -import static com.facebook.presto.spi.type.SmallintType.SMALLINT; -import static com.facebook.presto.spi.type.StandardTypes.DECIMAL; -import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; -import static com.facebook.presto.spi.type.TinyintType.TINYINT; -import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; -import static com.facebook.presto.spi.type.Varchars.isVarcharType; -import static com.facebook.presto.spi.type.Varchars.truncateToLength; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static java.lang.Math.max; -import static java.lang.Math.min; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; -import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; - -class ColumnarBinaryHiveRecordCursor - implements RecordCursor -{ - private final RecordReader recordReader; - private final K key; - private final BytesRefArrayWritable value; - - private final Type[] types; - private final HiveType[] hiveTypes; - - private final ObjectInspector[] fieldInspectors; // DON'T USE THESE UNLESS EXTRACTION WILL BE SLOW ANYWAY - - private final int[] hiveColumnIndexes; - - private final boolean[] loaded; - private final boolean[] booleans; - private final long[] longs; - private final double[] doubles; - private final Slice[] slices; - private final Object[] objects; - private final boolean[] nulls; - - private final long totalBytes; - private long completedBytes; - private boolean closed; - - private final HiveDecimalWritable decimalWritable = new HiveDecimalWritable(); - - private static final byte HIVE_EMPTY_STRING_BYTE = (byte) 0xbf; - - private static final int SIZE_OF_SHORT = 2; - private static final int SIZE_OF_INT = 4; - private static final int SIZE_OF_LONG = 8; - - private static final Set VALID_HIVE_STRING_TYPES = ImmutableSet.of(PrimitiveCategory.BINARY, PrimitiveCategory.VARCHAR, PrimitiveCategory.STRING); - private static final Set VALID_HIVE_STRUCTURAL_CATEGORIES = ImmutableSet.of(Category.LIST, Category.MAP, Category.STRUCT); - - public ColumnarBinaryHiveRecordCursor(RecordReader recordReader, - long totalBytes, - Properties splitSchema, - List columns, - TypeManager typeManager) - { - requireNonNull(recordReader, "recordReader is null"); - checkArgument(totalBytes >= 0, "totalBytes is negative"); - requireNonNull(splitSchema, "splitSchema is null"); - requireNonNull(columns, "columns is null"); - - this.recordReader = recordReader; - this.totalBytes = totalBytes; - this.key = recordReader.createKey(); - this.value = recordReader.createValue(); - - int size = columns.size(); - - this.types = new Type[size]; - this.hiveTypes = new HiveType[size]; - - this.fieldInspectors = new ObjectInspector[size]; - - this.hiveColumnIndexes = new int[size]; - - this.loaded = new boolean[size]; - this.booleans = new boolean[size]; - this.longs = new long[size]; - this.doubles = new double[size]; - this.slices = new Slice[size]; - this.objects = new Object[size]; - this.nulls = new boolean[size]; - - // initialize data columns - StructObjectInspector rowInspector = getTableObjectInspector(splitSchema); - - for (int i = 0; i < columns.size(); i++) { - HiveColumnHandle column = columns.get(i); - checkState(column.getColumnType() == REGULAR, "column type must be regular"); - - types[i] = typeManager.getType(column.getTypeSignature()); - hiveTypes[i] = column.getHiveType(); - hiveColumnIndexes[i] = column.getHiveColumnIndex(); - - fieldInspectors[i] = rowInspector.getStructFieldRef(column.getName()).getFieldObjectInspector(); - } - } - - @Override - public long getTotalBytes() - { - return totalBytes; - } - - @Override - public long getCompletedBytes() - { - if (!closed) { - updateCompletedBytes(); - } - return completedBytes; - } - - @Override - public long getReadTimeNanos() - { - return 0; - } - - private void updateCompletedBytes() - { - try { - long newCompletedBytes = (long) (totalBytes * recordReader.getProgress()); - completedBytes = min(totalBytes, max(completedBytes, newCompletedBytes)); - } - catch (IOException ignored) { - } - } - - @Override - public Type getType(int field) - { - return types[field]; - } - - @Override - public boolean advanceNextPosition() - { - try { - if (closed || !recordReader.next(key, value)) { - close(); - return false; - } - - // reset loaded flags - Arrays.fill(loaded, false); - - return true; - } - catch (IOException | RuntimeException e) { - closeWithSuppression(this, e); - throw new PrestoException(HIVE_CURSOR_ERROR, e); - } - } - - @Override - public boolean getBoolean(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - validateType(fieldId, BOOLEAN); - if (!loaded[fieldId]) { - parseBooleanColumn(fieldId); - } - return booleans[fieldId]; - } - - private void parseBooleanColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseBooleanColumn(column, bytes, start, length); - } - } - - private void parseBooleanColumn(int column, byte[] bytes, int start, int length) - { - if (length > 0) { - booleans[column] = bytes[start] != 0; - nulls[column] = false; - } - else { - nulls[column] = true; - } - } - - @Override - public long getLong(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - if (!types[fieldId].equals(BIGINT) && - !types[fieldId].equals(INTEGER) && - !types[fieldId].equals(SMALLINT) && - !types[fieldId].equals(TINYINT) && - !types[fieldId].equals(DATE) && - !types[fieldId].equals(TIMESTAMP) && - !isShortDecimal(types[fieldId]) && - !types[fieldId].equals(REAL)) { - // we don't use Preconditions.checkArgument because it requires boxing fieldId, which affects inner loop performance - throw new IllegalArgumentException( - format("Expected field to be %s, %s, %s, %s, %s, %s, %s or %s , actual %s (field %s)", TINYINT, SMALLINT, INTEGER, BIGINT, DATE, TIMESTAMP, DECIMAL, REAL, types[fieldId], fieldId)); - } - if (!loaded[fieldId]) { - parseLongColumn(fieldId); - } - return longs[fieldId]; - } - - private void parseLongColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw new PrestoException(HIVE_BAD_DATA, e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseLongColumn(column, bytes, start, length); - } - } - - private void parseLongColumn(int column, byte[] bytes, int start, int length) - { - if (length == 0) { - nulls[column] = true; - return; - } - nulls[column] = false; - if (hiveTypes[column].equals(HIVE_SHORT)) { - // the file format uses big endian - checkState(length == SIZE_OF_SHORT, "Short should be 2 bytes"); - longs[column] = Short.reverseBytes(ByteArrays.getShort(bytes, start)); - } - else if (hiveTypes[column].equals(HIVE_DATE)) { - checkState(length >= 1, "Date should be at least 1 byte"); - long daysSinceEpoch = readVInt(bytes, start, length); - longs[column] = daysSinceEpoch; - } - else if (hiveTypes[column].equals(HIVE_TIMESTAMP)) { - checkState(length >= 1, "Timestamp should be at least 1 byte"); - long seconds = TimestampWritable.getSeconds(bytes, start); - long nanos = (bytes[start] >> 7) != 0 ? TimestampWritable.getNanos(bytes, start + SIZE_OF_INT) : 0; - longs[column] = (seconds * 1000) + (nanos / 1_000_000); - } - else if (hiveTypes[column].equals(HIVE_BYTE)) { - checkState(length == 1, "Byte should be 1 byte"); - longs[column] = bytes[start]; - } - else if (hiveTypes[column].equals(HIVE_INT)) { - checkState(length >= 1, "Int should be at least 1 byte"); - if (length == 1) { - longs[column] = bytes[start]; - } - else { - longs[column] = readVInt(bytes, start, length); - } - } - else if (hiveTypes[column].equals(HIVE_LONG)) { - checkState(length >= 1, "Long should be at least 1 byte"); - if (length == 1) { - longs[column] = bytes[start]; - } - else { - longs[column] = readVInt(bytes, start, length); - } - } - else if (hiveTypes[column].equals(HIVE_FLOAT)) { - // the file format uses big endian - checkState(length == SIZE_OF_INT, "Float should be 4 bytes"); - int intBits = ByteArrays.getInt(bytes, start); - longs[column] = Integer.reverseBytes(intBits); - } - else { - throw new RuntimeException(format("%s is not a valid LONG type", hiveTypes[column])); - } - } - - private static long readVInt(byte[] bytes, int start, int length) - { - long value = 0; - for (int i = 1; i < length; i++) { - value <<= 8; - value |= (bytes[start + i] & 0xFF); - } - return WritableUtils.isNegativeVInt(bytes[start]) ? ~value : value; - } - - @Override - public double getDouble(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - validateType(fieldId, DOUBLE); - if (!loaded[fieldId]) { - parseDoubleColumn(fieldId); - } - return doubles[fieldId]; - } - - private void parseDoubleColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseDoubleColumn(column, bytes, start, length); - } - } - - private void parseDoubleColumn(int column, byte[] bytes, int start, int length) - { - if (length == 0) { - nulls[column] = true; - } - else { - checkState(hiveTypes[column].equals(HIVE_DOUBLE), "%s is not a valid DOUBLE type", hiveTypes[column]); - - nulls[column] = false; - // the file format uses big endian - checkState(length == SIZE_OF_LONG, "Double should be 8 bytes"); - long longBits = ByteArrays.getLong(bytes, start); - doubles[column] = Double.longBitsToDouble(Long.reverseBytes(longBits)); - } - } - - @Override - public Slice getSlice(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - Type type = types[fieldId]; - if (!isVarcharType(type) && !isCharType(type) && !type.equals(VARBINARY) && !isStructuralType(hiveTypes[fieldId]) && !isLongDecimal(type)) { - // we don't use Preconditions.checkArgument because it requires boxing fieldId, which affects inner loop performance - throw new IllegalArgumentException(format("Expected field to be VARCHAR, CHAR, VARBINARY or DECIMAL, actual %s (field %s)", type, fieldId)); - } - - if (!loaded[fieldId]) { - parseStringColumn(fieldId); - } - return slices[fieldId]; - } - - private void parseStringColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseStringColumn(column, bytes, start, length); - } - } - - private void parseStringColumn(int column, byte[] bytes, int start, int length) - { - checkState(isValidHiveStringType(hiveTypes[column]), "%s is not a valid STRING type", hiveTypes[column]); - if (length == 0) { - nulls[column] = true; - } - else { - nulls[column] = false; - // TODO: zero length BINARY is not supported. See https://issues.apache.org/jira/browse/HIVE-2483 - if (hiveTypes[column].equals(HiveType.HIVE_STRING) && (length == 1) && bytes[start] == HIVE_EMPTY_STRING_BYTE) { - slices[column] = Slices.EMPTY_SLICE; - } - else { - Slice value = Slices.wrappedBuffer(Arrays.copyOfRange(bytes, start, start + length)); - Type type = types[column]; - if (isVarcharType(type)) { - slices[column] = truncateToLength(value, type); - } - else { - slices[column] = value; - } - } - } - } - - private void parseCharColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseCharColumn(column, bytes, start, length); - } - } - - private void parseCharColumn(int column, byte[] bytes, int start, int length) - { - if (length == 0) { - nulls[column] = true; - } - else { - nulls[column] = false; - Slice value = Slices.wrappedBuffer(Arrays.copyOfRange(bytes, start, start + length)); - Type type = types[column]; - slices[column] = trimSpacesAndTruncateToLength(value, type); - } - } - - private void parseDecimalColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseDecimalColumn(column, bytes, start, length); - } - } - - private void parseDecimalColumn(int column, byte[] bytes, int start, int length) - { - if (length == 0) { - nulls[column] = true; - } - else { - nulls[column] = false; - decimalWritable.setFromBytes(bytes, start, length); - DecimalType columnType = (DecimalType) types[column]; - if (columnType.isShort()) { - longs[column] = getShortDecimalValue(decimalWritable, columnType.getScale()); - } - else { - slices[column] = getLongDecimalValue(decimalWritable, columnType.getScale()); - } - } - } - - private boolean isValidHiveStringType(HiveType hiveType) - { - return hiveType.getCategory() == Category.PRIMITIVE - && VALID_HIVE_STRING_TYPES.contains(((PrimitiveTypeInfo) hiveType.getTypeInfo()).getPrimitiveCategory()); - } - - @Override - public Object getObject(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - Type type = types[fieldId]; - if (!isStructuralType(hiveTypes[fieldId])) { - // we don't use Preconditions.checkArgument because it requires boxing fieldId, which affects inner loop performance - throw new IllegalArgumentException(format("Expected field to be structural, actual %s (field %s)", type, fieldId)); - } - - if (!loaded[fieldId]) { - parseObjectColumn(fieldId); - } - return objects[fieldId]; - } - - private void parseObjectColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseObjectColumn(column, bytes, start, length); - } - } - - private void parseObjectColumn(int column, byte[] bytes, int start, int length) - { - checkState(VALID_HIVE_STRUCTURAL_CATEGORIES.contains(hiveTypes[column].getCategory()), "%s is not a valid STRUCTURAL type", hiveTypes[column]); - if (length == 0) { - nulls[column] = true; - } - else { - nulls[column] = false; - LazyBinaryObject lazyObject = LazyBinaryFactory.createLazyBinaryObject(fieldInspectors[column]); - ByteArrayRef byteArrayRef = new ByteArrayRef(); - byteArrayRef.setData(bytes); - lazyObject.init(byteArrayRef, start, length); - objects[column] = getBlockObject(types[column], lazyObject.getObject(), fieldInspectors[column]); - } - } - - @Override - public boolean isNull(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - if (!loaded[fieldId]) { - parseColumn(fieldId); - } - return nulls[fieldId]; - } - - private void parseColumn(int column) - { - Type type = types[column]; - if (BOOLEAN.equals(type)) { - parseBooleanColumn(column); - } - else if (BIGINT.equals(type)) { - parseLongColumn(column); - } - else if (INTEGER.equals(type)) { - parseLongColumn(column); - } - else if (SMALLINT.equals(type)) { - parseLongColumn(column); - } - else if (TINYINT.equals(type)) { - parseLongColumn(column); - } - else if (DOUBLE.equals(type)) { - parseDoubleColumn(column); - } - else if (REAL.equals(type)) { - parseLongColumn(column); - } - else if (isVarcharType(type) || VARBINARY.equals(type)) { - parseStringColumn(column); - } - else if (isCharType(type)) { - parseCharColumn(column); - } - else if (isStructuralType(hiveTypes[column])) { - parseObjectColumn(column); - } - else if (DATE.equals(type)) { - parseLongColumn(column); - } - else if (TIMESTAMP.equals(type)) { - parseLongColumn(column); - } - else if (type instanceof DecimalType) { - parseDecimalColumn(column); - } - else { - throw new UnsupportedOperationException("Unsupported column type: " + type); - } - } - - private void validateType(int fieldId, Type type) - { - if (!types[fieldId].equals(type)) { - // we don't use Preconditions.checkArgument because it requires boxing fieldId, which affects inner loop performance - throw new IllegalArgumentException(format("Expected field to be %s, actual %s (field %s)", type, types[fieldId], fieldId)); - } - } - - @Override - public void close() - { - // some hive input formats are broken and bad things can happen if you close them multiple times - if (closed) { - return; - } - closed = true; - - updateCompletedBytes(); - - try { - recordReader.close(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarBinaryHiveRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarBinaryHiveRecordCursorProvider.java deleted file mode 100644 index 66879ad8424a4..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarBinaryHiveRecordCursorProvider.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.RecordCursor; -import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.type.TypeManager; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable; -import org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe; -import org.apache.hadoop.mapred.RecordReader; -import org.joda.time.DateTimeZone; - -import javax.inject.Inject; - -import java.util.List; -import java.util.Optional; -import java.util.Properties; - -import static com.facebook.presto.hive.HiveUtil.isDeserializerClass; -import static java.util.Objects.requireNonNull; - -public class ColumnarBinaryHiveRecordCursorProvider - implements HiveRecordCursorProvider -{ - private final HdfsEnvironment hdfsEnvironment; - - @Inject - public ColumnarBinaryHiveRecordCursorProvider(HdfsEnvironment hdfsEnvironment) - { - this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); - } - - @Override - public Optional createRecordCursor( - String clientId, - Configuration configuration, - ConnectorSession session, - Path path, - long start, - long length, - Properties schema, - List columns, - TupleDomain effectivePredicate, - DateTimeZone hiveStorageTimeZone, - TypeManager typeManager) - { - if (!isDeserializerClass(schema, LazyBinaryColumnarSerDe.class)) { - return Optional.empty(); - } - - RecordReader recordReader = hdfsEnvironment.doAs(session.getUser(), - () -> HiveUtil.createRecordReader(configuration, path, start, length, schema, columns)); - - return Optional.of(new ColumnarBinaryHiveRecordCursor<>( - bytesRecordReader(recordReader), - length, - schema, - columns, - typeManager)); - } - - @SuppressWarnings("unchecked") - private static RecordReader bytesRecordReader(RecordReader recordReader) - { - return (RecordReader) recordReader; - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarTextHiveRecordCursor.java b/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarTextHiveRecordCursor.java deleted file mode 100644 index 330dfbe936010..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarTextHiveRecordCursor.java +++ /dev/null @@ -1,657 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.RecordCursor; -import com.facebook.presto.spi.block.Block; -import com.facebook.presto.spi.type.DecimalType; -import com.facebook.presto.spi.type.Decimals; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeManager; -import com.google.common.base.Throwables; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; -import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable; -import org.apache.hadoop.hive.serde2.columnar.BytesRefWritable; -import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef; -import org.apache.hadoop.hive.serde2.lazy.LazyFactory; -import org.apache.hadoop.hive.serde2.lazy.LazyObject; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.mapred.RecordReader; -import org.joda.time.DateTimeZone; - -import java.io.IOException; -import java.math.BigDecimal; -import java.util.Arrays; -import java.util.List; -import java.util.Properties; - -import static com.facebook.presto.hive.HiveBooleanParser.isFalse; -import static com.facebook.presto.hive.HiveBooleanParser.isTrue; -import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.REGULAR; -import static com.facebook.presto.hive.HiveDecimalParser.parseHiveDecimal; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_CURSOR_ERROR; -import static com.facebook.presto.hive.HiveUtil.base64Decode; -import static com.facebook.presto.hive.HiveUtil.closeWithSuppression; -import static com.facebook.presto.hive.HiveUtil.getTableObjectInspector; -import static com.facebook.presto.hive.HiveUtil.isStructuralType; -import static com.facebook.presto.hive.HiveUtil.parseHiveDate; -import static com.facebook.presto.hive.HiveUtil.parseHiveTimestamp; -import static com.facebook.presto.hive.NumberParser.parseDouble; -import static com.facebook.presto.hive.NumberParser.parseFloat; -import static com.facebook.presto.hive.NumberParser.parseLong; -import static com.facebook.presto.hive.util.SerDeUtils.getBlockObject; -import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; -import static com.facebook.presto.spi.type.Chars.isCharType; -import static com.facebook.presto.spi.type.Chars.trimSpacesAndTruncateToLength; -import static com.facebook.presto.spi.type.DateType.DATE; -import static com.facebook.presto.spi.type.Decimals.isShortDecimal; -import static com.facebook.presto.spi.type.DoubleType.DOUBLE; -import static com.facebook.presto.spi.type.IntegerType.INTEGER; -import static com.facebook.presto.spi.type.RealType.REAL; -import static com.facebook.presto.spi.type.SmallintType.SMALLINT; -import static com.facebook.presto.spi.type.StandardTypes.DECIMAL; -import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; -import static com.facebook.presto.spi.type.TinyintType.TINYINT; -import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; -import static com.facebook.presto.spi.type.Varchars.isVarcharType; -import static com.facebook.presto.spi.type.Varchars.truncateToLength; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static java.lang.Float.floatToRawIntBits; -import static java.lang.Math.max; -import static java.lang.Math.min; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; - -class ColumnarTextHiveRecordCursor - implements RecordCursor -{ - private final RecordReader recordReader; - private final K key; - private final BytesRefArrayWritable value; - - private final Type[] types; - private final HiveType[] hiveTypes; - - private final ObjectInspector[] fieldInspectors; // DON'T USE THESE UNLESS EXTRACTION WILL BE SLOW ANYWAY - - private final int[] hiveColumnIndexes; - - private final boolean[] loaded; - private final boolean[] booleans; - private final long[] longs; - private final double[] doubles; - private final Slice[] slices; - private final Object[] objects; - private final boolean[] nulls; - - private final long totalBytes; - private final DateTimeZone hiveStorageTimeZone; - - private long completedBytes; - private boolean closed; - - public ColumnarTextHiveRecordCursor( - RecordReader recordReader, - long totalBytes, - Properties splitSchema, - List columns, - DateTimeZone hiveStorageTimeZone, - TypeManager typeManager) - { - requireNonNull(recordReader, "recordReader is null"); - checkArgument(totalBytes >= 0, "totalBytes is negative"); - requireNonNull(splitSchema, "splitSchema is null"); - requireNonNull(columns, "columns is null"); - requireNonNull(hiveStorageTimeZone, "hiveStorageTimeZone is null"); - - this.recordReader = recordReader; - this.totalBytes = totalBytes; - this.key = recordReader.createKey(); - this.value = recordReader.createValue(); - this.hiveStorageTimeZone = hiveStorageTimeZone; - - int size = columns.size(); - - this.types = new Type[size]; - this.hiveTypes = new HiveType[size]; - - this.fieldInspectors = new ObjectInspector[size]; - - this.hiveColumnIndexes = new int[size]; - - this.loaded = new boolean[size]; - this.booleans = new boolean[size]; - this.longs = new long[size]; - this.doubles = new double[size]; - this.slices = new Slice[size]; - this.objects = new Object[size]; - this.nulls = new boolean[size]; - - // initialize data columns - StructObjectInspector rowInspector = getTableObjectInspector(splitSchema); - - for (int i = 0; i < columns.size(); i++) { - HiveColumnHandle column = columns.get(i); - checkState(column.getColumnType() == REGULAR, "column type must be regular"); - - types[i] = typeManager.getType(column.getTypeSignature()); - hiveTypes[i] = column.getHiveType(); - hiveColumnIndexes[i] = column.getHiveColumnIndex(); - - fieldInspectors[i] = rowInspector.getStructFieldRef(column.getName()).getFieldObjectInspector(); - } - } - - @Override - public long getTotalBytes() - { - return totalBytes; - } - - @Override - public long getCompletedBytes() - { - if (!closed) { - updateCompletedBytes(); - } - return completedBytes; - } - - @Override - public long getReadTimeNanos() - { - return 0; - } - - private void updateCompletedBytes() - { - try { - long newCompletedBytes = (long) (totalBytes * recordReader.getProgress()); - completedBytes = min(totalBytes, max(completedBytes, newCompletedBytes)); - } - catch (IOException ignored) { - } - } - - @Override - public Type getType(int field) - { - return types[field]; - } - - @Override - public boolean advanceNextPosition() - { - try { - if (closed || !recordReader.next(key, value)) { - close(); - return false; - } - - // reset loaded flags - Arrays.fill(loaded, false); - - return true; - } - catch (IOException | RuntimeException e) { - closeWithSuppression(this, e); - throw new PrestoException(HIVE_CURSOR_ERROR, e); - } - } - - @Override - public boolean getBoolean(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - validateType(fieldId, boolean.class); - if (!loaded[fieldId]) { - parseBooleanColumn(fieldId); - } - return booleans[fieldId]; - } - - private void parseBooleanColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseBooleanColumn(column, bytes, start, length); - } - } - - private void parseBooleanColumn(int column, byte[] bytes, int start, int length) - { - boolean wasNull; - if (isTrue(bytes, start, length)) { - booleans[column] = true; - wasNull = false; - } - else if (isFalse(bytes, start, length)) { - booleans[column] = false; - wasNull = false; - } - else { - wasNull = true; - } - nulls[column] = wasNull; - } - - @Override - public long getLong(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - if (!types[fieldId].equals(BIGINT) && - !types[fieldId].equals(INTEGER) && - !types[fieldId].equals(SMALLINT) && - !types[fieldId].equals(TINYINT) && - !types[fieldId].equals(DATE) && - !types[fieldId].equals(TIMESTAMP) && - !isShortDecimal(types[fieldId]) && - !types[fieldId].equals(REAL)) { - // we don't use Preconditions.checkArgument because it requires boxing fieldId, which affects inner loop performance - throw new IllegalArgumentException( - format("Expected field to be %s, %s, %s, %s, %s, %s, %s or %s , actual %s (field %s)", TINYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, DATE, TIMESTAMP, REAL, types[fieldId], fieldId)); - } - - if (!loaded[fieldId]) { - parseLongColumn(fieldId); - } - return longs[fieldId]; - } - - private void parseLongColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseLongColumn(column, bytes, start, length); - } - } - - private void parseLongColumn(int column, byte[] bytes, int start, int length) - { - boolean wasNull; - if (length == 0 || (length == "\\N".length() && bytes[start] == '\\' && bytes[start + 1] == 'N')) { - wasNull = true; - } - else if (hiveTypes[column].equals(HiveType.HIVE_DATE)) { - String value = new String(bytes, start, length); - longs[column] = parseHiveDate(value); - wasNull = false; - } - else if (hiveTypes[column].equals(HiveType.HIVE_TIMESTAMP)) { - String value = new String(bytes, start, length); - longs[column] = parseHiveTimestamp(value, hiveStorageTimeZone); - wasNull = false; - } - else if (hiveTypes[column].equals(HiveType.HIVE_FLOAT)) { - longs[column] = floatToRawIntBits(parseFloat(bytes, start, length)); - wasNull = false; - } - else { - longs[column] = parseLong(bytes, start, length); - wasNull = false; - } - nulls[column] = wasNull; - } - - @Override - public double getDouble(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - validateType(fieldId, double.class); - if (!loaded[fieldId]) { - parseDoubleColumn(fieldId); - } - return doubles[fieldId]; - } - - private void parseDoubleColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseDoubleColumn(column, bytes, start, length); - } - } - - private void parseDoubleColumn(int column, byte[] bytes, int start, int length) - { - boolean wasNull; - if (length == 0 || (length == "\\N".length() && bytes[start] == '\\' && bytes[start + 1] == 'N')) { - wasNull = true; - } - else { - doubles[column] = parseDouble(bytes, start, length); - wasNull = false; - } - nulls[column] = wasNull; - } - - @Override - public Slice getSlice(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - validateType(fieldId, Slice.class); - if (!loaded[fieldId]) { - parseStringColumn(fieldId); - } - return slices[fieldId]; - } - - private void parseStringColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseStringColumn(column, bytes, start, length); - } - } - - private void parseStringColumn(int column, byte[] bytes, int start, int length) - { - boolean wasNull; - if (length == "\\N".length() && bytes[start] == '\\' && bytes[start + 1] == 'N') { - wasNull = true; - } - else { - Type type = types[column]; - Slice value = Slices.wrappedBuffer(Arrays.copyOfRange(bytes, start, start + length)); - if (isVarcharType(type)) { - slices[column] = truncateToLength(value, type); - } - else if (isCharType(type)) { - slices[column] = trimSpacesAndTruncateToLength(value, type); - } - // this is unbelievably stupid but Hive base64 encodes binary data in a binary file format - else if (type.equals(VARBINARY)) { - // and yes we end up with an extra copy here because the Base64 only handles whole arrays - slices[column] = base64Decode(value.getBytes()); - } - else { - slices[column] = value; - } - wasNull = false; - } - nulls[column] = wasNull; - } - - private void parseDecimalColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseDecimalColumn(column, bytes, start, length); - } - } - - private void parseDecimalColumn(int column, byte[] bytes, int start, int length) - { - boolean wasNull; - if (length == 0 || (length == "\\N".length() && bytes[start] == '\\' && bytes[start + 1] == 'N')) { - wasNull = true; - } - else { - DecimalType columnType = (DecimalType) types[column]; - BigDecimal decimal = parseHiveDecimal(bytes, start, length, columnType); - - if (columnType.isShort()) { - longs[column] = decimal.unscaledValue().longValue(); - } - else { - slices[column] = Decimals.encodeUnscaledValue(decimal.unscaledValue()); - } - - wasNull = false; - } - nulls[column] = wasNull; - } - - @Override - public Object getObject(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - validateType(fieldId, Block.class); - if (!loaded[fieldId]) { - parseObjectColumn(fieldId); - } - return objects[fieldId]; - } - - private void parseObjectColumn(int column) - { - loaded[column] = true; - - if (hiveColumnIndexes[column] >= value.size()) { - // this partition may contain fewer fields than what's declared in the schema - // this happens when additional columns are added to the hive table after a partition has been created - nulls[column] = true; - } - else { - BytesRefWritable fieldData = value.unCheckedGet(hiveColumnIndexes[column]); - - byte[] bytes; - try { - bytes = fieldData.getData(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - - int start = fieldData.getStart(); - int length = fieldData.getLength(); - - parseObjectColumn(column, bytes, start, length); - } - } - - private void parseObjectColumn(int column, byte[] bytes, int start, int length) - { - boolean wasNull; - if (length == "\\N".length() && bytes[start] == '\\' && bytes[start + 1] == 'N') { - wasNull = true; - } - else { - LazyObject lazyObject = LazyFactory.createLazyObject(fieldInspectors[column]); - ByteArrayRef byteArrayRef = new ByteArrayRef(); - byteArrayRef.setData(bytes); - lazyObject.init(byteArrayRef, start, length); - objects[column] = getBlockObject(types[column], lazyObject.getObject(), fieldInspectors[column]); - wasNull = false; - } - nulls[column] = wasNull; - } - - @Override - public boolean isNull(int fieldId) - { - checkState(!closed, "Cursor is closed"); - - if (!loaded[fieldId]) { - parseColumn(fieldId); - } - return nulls[fieldId]; - } - - private void parseColumn(int column) - { - Type type = types[column]; - if (type.equals(BOOLEAN)) { - parseBooleanColumn(column); - } - else if (type.equals(BIGINT)) { - parseLongColumn(column); - } - else if (type.equals(INTEGER)) { - parseLongColumn(column); - } - else if (type.equals(SMALLINT)) { - parseLongColumn(column); - } - else if (type.equals(TINYINT)) { - parseLongColumn(column); - } - else if (type.equals(REAL)) { - parseLongColumn(column); - } - else if (type.equals(DOUBLE)) { - parseDoubleColumn(column); - } - else if (isVarcharType(type) || VARBINARY.equals(type) || isCharType(type)) { - parseStringColumn(column); - } - else if (isStructuralType(hiveTypes[column])) { - parseObjectColumn(column); - } - else if (type.equals(DATE)) { - parseLongColumn(column); - } - else if (type.equals(TIMESTAMP)) { - parseLongColumn(column); - } - else if (type instanceof DecimalType) { - parseDecimalColumn(column); - } - else { - throw new UnsupportedOperationException("Unsupported column type: " + type); - } - } - - private void validateType(int fieldId, Class type) - { - if (!types[fieldId].getJavaType().equals(type)) { - // we don't use Preconditions.checkArgument because it requires boxing fieldId, which affects inner loop performance - throw new IllegalArgumentException(String.format("Expected field to be %s, actual %s (field %s)", type, types[fieldId], fieldId)); - } - } - - @Override - public void close() - { - // some hive input formats are broken and bad things can happen if you close them multiple times - if (closed) { - return; - } - closed = true; - - updateCompletedBytes(); - - try { - recordReader.close(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarTextHiveRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarTextHiveRecordCursorProvider.java deleted file mode 100644 index 7618b459cced7..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ColumnarTextHiveRecordCursorProvider.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.RecordCursor; -import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.type.TypeManager; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable; -import org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe; -import org.apache.hadoop.mapred.RecordReader; -import org.joda.time.DateTimeZone; - -import javax.inject.Inject; - -import java.util.List; -import java.util.Optional; -import java.util.Properties; - -import static com.facebook.presto.hive.HiveUtil.isDeserializerClass; -import static java.util.Objects.requireNonNull; - -public class ColumnarTextHiveRecordCursorProvider - implements HiveRecordCursorProvider -{ - private final HdfsEnvironment hdfsEnvironment; - - @Inject - public ColumnarTextHiveRecordCursorProvider(HdfsEnvironment hdfsEnvironment) - { - this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); - } - - @Override - public Optional createRecordCursor( - String clientId, - Configuration configuration, - ConnectorSession session, - Path path, - long start, - long length, - Properties schema, - List columns, - TupleDomain effectivePredicate, - DateTimeZone hiveStorageTimeZone, - TypeManager typeManager) - { - if (!isDeserializerClass(schema, ColumnarSerDe.class)) { - return Optional.empty(); - } - - RecordReader recordReader = hdfsEnvironment.doAs(session.getUser(), - () -> HiveUtil.createRecordReader(configuration, path, start, length, schema, columns)); - - return Optional.of(new ColumnarTextHiveRecordCursor<>( - columnarTextRecordReader(recordReader), - length, - schema, - columns, - hiveStorageTimeZone, - typeManager)); - } - - @SuppressWarnings("unchecked") - private static RecordReader columnarTextRecordReader(RecordReader recordReader) - { - return (RecordReader) recordReader; - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java b/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java new file mode 100644 index 0000000000000..1184b19464ffd --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/FileFormatDataSourceStats.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import io.airlift.stats.DistributionStat; +import io.airlift.stats.TimeStat; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +public class FileFormatDataSourceStats +{ + private final DistributionStat readBytes = new DistributionStat(); + private final DistributionStat loadedBlockBytes = new DistributionStat(); + private final DistributionStat maxCombinedBytesPerRow = new DistributionStat(); + private final TimeStat time0Bto100KB = new TimeStat(MILLISECONDS); + private final TimeStat time100KBto1MB = new TimeStat(MILLISECONDS); + private final TimeStat time1MBto10MB = new TimeStat(MILLISECONDS); + private final TimeStat time10MBPlus = new TimeStat(MILLISECONDS); + + @Managed + @Nested + public DistributionStat getReadBytes() + { + return readBytes; + } + + @Managed + @Nested + public DistributionStat getLoadedBlockBytes() + { + return loadedBlockBytes; + } + + @Managed + @Nested + public DistributionStat getMaxCombinedBytesPerRow() + { + return maxCombinedBytesPerRow; + } + + @Managed + @Nested + public TimeStat get0Bto100KB() + { + return time0Bto100KB; + } + + @Managed + @Nested + public TimeStat get100KBto1MB() + { + return time100KBto1MB; + } + + @Managed + @Nested + public TimeStat get1MBto10MB() + { + return time1MBto10MB; + } + + @Managed + @Nested + public TimeStat get10MBPlus() + { + return time10MBPlus; + } + + public void readDataBytesPerSecond(long bytes, long nanos) + { + readBytes.add(bytes); + if (bytes < 100 * 1024) { + time0Bto100KB.add(nanos, NANOSECONDS); + } + else if (bytes < 1024 * 1024) { + time100KBto1MB.add(nanos, NANOSECONDS); + } + else if (bytes < 10 * 1024 * 1024) { + time1MBto10MB.add(nanos, NANOSECONDS); + } + else { + time10MBPlus.add(nanos, NANOSECONDS); + } + } + + public void addLoadedBlockSize(long bytes) + { + loadedBlockBytes.add(bytes); + } + + public void addMaxCombinedBytesPerRow(long bytes) + { + maxCombinedBytesPerRow.add(bytes); + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursor.java b/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursor.java index 87ce874498406..0b2c739f8e38c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursor.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursor.java @@ -23,15 +23,18 @@ import com.google.common.base.Throwables; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; -import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.serde2.Deserializer; import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.BinaryComparable; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; import org.apache.hadoop.mapred.RecordReader; import org.joda.time.DateTimeZone; @@ -359,24 +362,27 @@ private void parseStringColumn(int column) nulls[column] = true; } else { - Object fieldValue = ((PrimitiveObjectInspector) fieldInspectors[column]).getPrimitiveJavaObject(fieldData); + Object fieldValue = ((PrimitiveObjectInspector) fieldInspectors[column]).getPrimitiveWritableObject(fieldData); checkState(fieldValue != null, "fieldValue should not be null"); - Slice value; - if (fieldValue instanceof String) { - value = Slices.utf8Slice((String) fieldValue); + BinaryComparable hiveValue; + if (fieldValue instanceof Text) { + hiveValue = (Text) fieldValue; } - else if (fieldValue instanceof byte[]) { - value = Slices.wrappedBuffer((byte[]) fieldValue); + else if (fieldValue instanceof BytesWritable) { + hiveValue = (BytesWritable) fieldValue; } - else if (fieldValue instanceof HiveVarchar) { - value = Slices.utf8Slice(((HiveVarchar) fieldValue).getValue()); + else if (fieldValue instanceof HiveVarcharWritable) { + hiveValue = ((HiveVarcharWritable) fieldValue).getTextValue(); } - else if (fieldValue instanceof HiveChar) { - value = Slices.utf8Slice(((HiveChar) fieldValue).getValue()); + else if (fieldValue instanceof HiveCharWritable) { + hiveValue = ((HiveCharWritable) fieldValue).getTextValue(); } else { throw new IllegalStateException("unsupported string field type: " + fieldValue.getClass().getName()); } + + // create a slice view over the hive value and trim to character limits + Slice value = Slices.wrappedBuffer(hiveValue.getBytes(), 0, hiveValue.getLength()); Type type = types[column]; if (isVarcharType(type)) { value = truncateToLength(value, type); @@ -384,7 +390,9 @@ else if (fieldValue instanceof HiveChar) { if (isCharType(type)) { value = trimSpacesAndTruncateToLength(value, type); } - slices[column] = value; + + // store a copy of the bytes, since the hive reader can reuse the underlying buffer + slices[column] = Slices.copyOf(value); nulls[column] = false; } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java b/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java index fb38275f51fc8..6b277f4141296 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationUpdater.java @@ -133,7 +133,7 @@ private static Configuration readConfiguration(List resourcePaths) return result; } - public void updateConfiguration(PrestoHadoopConfiguration config) + public void updateConfiguration(Configuration config) { copy(resourcesConfiguration, config); @@ -251,5 +251,11 @@ public void reloadCachedMappings() { // no-op } + + @Override + public void reloadCachedMappings(List names) + { + // no-op + } } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java index bea8e12b58e85..716d9414bd647 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java @@ -96,19 +96,14 @@ public class HiveClientConfig private DataSize orcMaxMergeDistance = new DataSize(1, MEGABYTE); private DataSize orcMaxBufferSize = new DataSize(8, MEGABYTE); private DataSize orcStreamBufferSize = new DataSize(8, MEGABYTE); + private DataSize orcMaxReadBlockSize = new DataSize(16, MEGABYTE); - private boolean rcfileOptimizedReaderEnabled = true; - private boolean rcfileOptimizedWriterEnabled; + private boolean rcfileOptimizedWriterEnabled = true; + private boolean rcfileWriterValidate; private HiveMetastoreAuthenticationType hiveMetastoreAuthenticationType = HiveMetastoreAuthenticationType.NONE; - private String hiveMetastoreServicePrincipal; - private String hiveMetastoreClientPrincipal; - private String hiveMetastoreClientKeytab; - private HdfsAuthenticationType hdfsAuthenticationType = HdfsAuthenticationType.NONE; private boolean hdfsImpersonationEnabled; - private String hdfsPrestoPrincipal; - private String hdfsPrestoKeytab; private boolean skipDeletionForAlter; @@ -656,29 +651,28 @@ public HiveClientConfig setOrcStreamBufferSize(DataSize orcStreamBufferSize) return this; } - public boolean isOrcBloomFiltersEnabled() + @NotNull + public DataSize getOrcMaxReadBlockSize() { - return orcBloomFiltersEnabled; + return orcMaxReadBlockSize; } - @Config("hive.orc.bloom-filters.enabled") - public HiveClientConfig setOrcBloomFiltersEnabled(boolean orcBloomFiltersEnabled) + @Config("hive.orc.max-read-block-size") + public HiveClientConfig setOrcMaxReadBlockSize(DataSize orcMaxReadBlockSize) { - this.orcBloomFiltersEnabled = orcBloomFiltersEnabled; + this.orcMaxReadBlockSize = orcMaxReadBlockSize; return this; } - @Deprecated - public boolean isRcfileOptimizedReaderEnabled() + public boolean isOrcBloomFiltersEnabled() { - return rcfileOptimizedReaderEnabled; + return orcBloomFiltersEnabled; } - @Deprecated - @Config("hive.rcfile-optimized-reader.enabled") - public HiveClientConfig setRcfileOptimizedReaderEnabled(boolean rcfileOptimizedReaderEnabled) + @Config("hive.orc.bloom-filters.enabled") + public HiveClientConfig setOrcBloomFiltersEnabled(boolean orcBloomFiltersEnabled) { - this.rcfileOptimizedReaderEnabled = rcfileOptimizedReaderEnabled; + this.orcBloomFiltersEnabled = orcBloomFiltersEnabled; return this; } @@ -696,6 +690,19 @@ public HiveClientConfig setRcfileOptimizedWriterEnabled(boolean rcfileOptimizedW return this; } + public boolean isRcfileWriterValidate() + { + return rcfileWriterValidate; + } + + @Config("hive.rcfile.writer.validate") + @ConfigDescription("Validate RCFile after write by re-reading the whole file") + public HiveClientConfig setRcfileWriterValidate(boolean rcfileWriterValidate) + { + this.rcfileWriterValidate = rcfileWriterValidate; + return this; + } + public boolean isAssumeCanonicalPartitionKeys() { return assumeCanonicalPartitionKeys; @@ -727,6 +734,7 @@ public enum HiveMetastoreAuthenticationType KERBEROS } + @NotNull public HiveMetastoreAuthenticationType getHiveMetastoreAuthenticationType() { return hiveMetastoreAuthenticationType; @@ -740,51 +748,13 @@ public HiveClientConfig setHiveMetastoreAuthenticationType(HiveMetastoreAuthenti return this; } - public String getHiveMetastoreServicePrincipal() - { - return hiveMetastoreServicePrincipal; - } - - @Config("hive.metastore.service.principal") - @ConfigDescription("Hive Metastore service principal") - public HiveClientConfig setHiveMetastoreServicePrincipal(String hiveMetastoreServicePrincipal) - { - this.hiveMetastoreServicePrincipal = hiveMetastoreServicePrincipal; - return this; - } - - public String getHiveMetastoreClientPrincipal() - { - return hiveMetastoreClientPrincipal; - } - - @Config("hive.metastore.client.principal") - @ConfigDescription("Hive Metastore client principal") - public HiveClientConfig setHiveMetastoreClientPrincipal(String hiveMetastoreClientPrincipal) - { - this.hiveMetastoreClientPrincipal = hiveMetastoreClientPrincipal; - return this; - } - - public String getHiveMetastoreClientKeytab() - { - return hiveMetastoreClientKeytab; - } - - @Config("hive.metastore.client.keytab") - @ConfigDescription("Hive Metastore client keytab location") - public HiveClientConfig setHiveMetastoreClientKeytab(String hiveMetastoreClientKeytab) - { - this.hiveMetastoreClientKeytab = hiveMetastoreClientKeytab; - return this; - } - public enum HdfsAuthenticationType { NONE, KERBEROS, } + @NotNull public HdfsAuthenticationType getHdfsAuthenticationType() { return hdfsAuthenticationType; @@ -811,32 +781,6 @@ public HiveClientConfig setHdfsImpersonationEnabled(boolean hdfsImpersonationEna return this; } - public String getHdfsPrestoPrincipal() - { - return hdfsPrestoPrincipal; - } - - @Config("hive.hdfs.presto.principal") - @ConfigDescription("Presto principal used to access HDFS") - public HiveClientConfig setHdfsPrestoPrincipal(String hdfsPrestoPrincipal) - { - this.hdfsPrestoPrincipal = hdfsPrestoPrincipal; - return this; - } - - public String getHdfsPrestoKeytab() - { - return hdfsPrestoKeytab; - } - - @Config("hive.hdfs.presto.keytab") - @ConfigDescription("Presto keytab used to access HDFS") - public HiveClientConfig setHdfsPrestoKeytab(String hdfsPrestoKeytab) - { - this.hdfsPrestoKeytab = hdfsPrestoKeytab; - return this; - } - public boolean isSkipDeletionForAlter() { return skipDeletionForAlter; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java index 5af9598f9dc9b..d9bd8a01bbbe4 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java @@ -81,7 +81,7 @@ public void configure(Binder binder) binder.bind(HiveTableProperties.class).in(Scopes.SINGLETON); binder.bind(NamenodeStats.class).in(Scopes.SINGLETON); - newExporter(binder).export(NamenodeStats.class).as(generatedNameOf(NamenodeStats.class)); + newExporter(binder).export(NamenodeStats.class).as(generatedNameOf(NamenodeStats.class, connectorId)); binder.bind(HiveMetastoreClientFactory.class).in(Scopes.SINGLETON); binder.bind(HiveCluster.class).to(StaticHiveCluster.class).in(Scopes.SINGLETON); @@ -93,8 +93,6 @@ public void configure(Binder binder) Multibinder recordCursorProviderBinder = newSetBinder(binder, HiveRecordCursorProvider.class); recordCursorProviderBinder.addBinding().to(ParquetRecordCursorProvider.class).in(Scopes.SINGLETON); - recordCursorProviderBinder.addBinding().to(ColumnarTextHiveRecordCursorProvider.class).in(Scopes.SINGLETON); - recordCursorProviderBinder.addBinding().to(ColumnarBinaryHiveRecordCursorProvider.class).in(Scopes.SINGLETON); recordCursorProviderBinder.addBinding().to(GenericHiveRecordCursorProvider.class).in(Scopes.SINGLETON); newSetBinder(binder, EventClient.class).addBinding().to(HiveEventClient.class).in(Scopes.SINGLETON); @@ -110,6 +108,9 @@ public void configure(Binder binder) jsonCodecBinder(binder).bindJsonCodec(PartitionUpdate.class); + binder.bind(FileFormatDataSourceStats.class).in(Scopes.SINGLETON); + newExporter(binder).export(FileFormatDataSourceStats.class).as(generatedNameOf(FileFormatDataSourceStats.class, connectorId)); + Multibinder pageSourceFactoryBinder = newSetBinder(binder, HivePageSourceFactory.class); pageSourceFactoryBinder.addBinding().to(OrcPageSourceFactory.class).in(Scopes.SINGLETON); pageSourceFactoryBinder.addBinding().to(DwrfPageSourceFactory.class).in(Scopes.SINGLETON); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java index e48b429c6c859..23563bf47d0c0 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.hive; -import com.google.common.collect.ImmutableClassToInstanceMap; import org.apache.hadoop.conf.Configuration; import javax.inject.Inject; @@ -45,7 +44,7 @@ public class HiveHdfsConfiguration @Override protected Configuration initialValue() { - PrestoHadoopConfiguration configuration = new PrestoHadoopConfiguration(ImmutableClassToInstanceMap.of()); + Configuration configuration = new Configuration(false); copy(INITIAL_CONFIGURATION, configuration); updater.updateConfiguration(configuration); return configuration; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java index 5703ec7c1f0ab..586067009fb43 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java @@ -23,6 +23,7 @@ import com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore.WriteMode; import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.hive.metastore.Table; +import com.facebook.presto.hive.statistics.HiveStatisticsProvider; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorInsertTableHandle; @@ -51,6 +52,7 @@ import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.security.PrivilegeInfo; +import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.google.common.annotations.VisibleForTesting; @@ -83,6 +85,7 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import static com.facebook.presto.hive.HiveBucketing.getHiveBucketHandle; import static com.facebook.presto.hive.HiveColumnHandle.BUCKET_COLUMN_NAME; import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.HIDDEN; import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; @@ -98,6 +101,7 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; import static com.facebook.presto.hive.HivePartitionManager.extractPartitionKeyValues; import static com.facebook.presto.hive.HiveSessionProperties.isBucketExecutionEnabled; +import static com.facebook.presto.hive.HiveSessionProperties.isStatisticsEnabled; import static com.facebook.presto.hive.HiveTableProperties.BUCKETED_BY_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.BUCKET_COUNT_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.EXTERNAL_LOCATION_PROPERTY; @@ -121,17 +125,21 @@ import static com.facebook.presto.hive.HiveWriteUtils.isWritableType; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.toHivePrivilege; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; +import static com.facebook.presto.hive.metastore.MetastoreUtil.getProtectMode; +import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyOnline; import static com.facebook.presto.hive.metastore.PrincipalType.USER; import static com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore.WriteMode.DIRECT_TO_TARGET_EXISTING_DIRECTORY; import static com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore.WriteMode.DIRECT_TO_TARGET_NEW_DIRECTORY; import static com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore.WriteMode.STAGE_AND_MOVE_TO_TARGET_DIRECTORY; import static com.facebook.presto.hive.metastore.StorageFormat.VIEW_STORAGE_FORMAT; import static com.facebook.presto.hive.metastore.StorageFormat.fromHiveStorageFormat; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.INVALID_SCHEMA_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.SCHEMA_NOT_EMPTY; import static com.facebook.presto.spi.predicate.TupleDomain.withColumnDomains; +import static com.facebook.presto.spi.statistics.TableStatistics.EMPTY_STATISTICS; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; @@ -169,6 +177,7 @@ public class HiveMetadata private final HiveStorageFormat defaultStorageFormat; private final TypeTranslator typeTranslator; private final String prestoVersion; + private final HiveStatisticsProvider hiveStatisticsProvider; public HiveMetadata( String connectorId, @@ -186,7 +195,8 @@ public HiveMetadata( TableParameterCodec tableParameterCodec, JsonCodec partitionUpdateCodec, TypeTranslator typeTranslator, - String prestoVersion) + String prestoVersion, + HiveStatisticsProvider hiveStatisticsProvider) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); @@ -206,6 +216,7 @@ public HiveMetadata( this.defaultStorageFormat = requireNonNull(defaultStorageFormat, "defaultStorageFormat is null"); this.typeTranslator = requireNonNull(typeTranslator, "typeTranslator is null"); this.prestoVersion = requireNonNull(prestoVersion, "prestoVersion is null"); + this.hiveStatisticsProvider = requireNonNull(hiveStatisticsProvider, "hiveStatisticsProvider is null"); } public SemiTransactionalHiveMetastore getMetastore() @@ -223,9 +234,11 @@ public List listSchemaNames(ConnectorSession session) public HiveTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) { requireNonNull(tableName, "tableName is null"); - if (!metastore.getTable(tableName.getSchemaName(), tableName.getTableName()).isPresent()) { + Optional table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); + if (!table.isPresent()) { return null; } + verifyOnline(tableName, Optional.empty(), getProtectMode(table.get()), table.get().getParameters()); return new HiveTableHandle(connectorId, tableName.getSchemaName(), tableName.getTableName()); } @@ -347,6 +360,17 @@ public Map> listTableColumns(ConnectorSess return columns.build(); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint) + { + if (!isStatisticsEnabled(session)) { + return EMPTY_STATISTICS; + } + List hivePartitions = partitionManager.getPartitions(metastore, tableHandle, constraint).getPartitions(); + Map tableColumns = getColumnHandles(session, tableHandle); + return hiveStatisticsProvider.getTableStatistics(session, tableHandle, hivePartitions, tableColumns); + } + private List listTables(ConnectorSession session, SchemaTablePrefix prefix) { if (prefix.getSchemaName() == null || prefix.getTableName() == null) { @@ -715,10 +739,9 @@ private List computeFileNamesForMissingBuckets(HiveStorageFormat storage // fast path for common case return ImmutableList.of(); } - JobConf conf = new JobConf(hdfsEnvironment.getConfiguration(targetPath)); + JobConf conf = toJobConf(hdfsEnvironment.getConfiguration(targetPath)); String fileExtension = HiveWriterFactory.getFileExtension(conf, fromHiveStorageFormat(storageFormat)); - Set fileNames = partitionUpdate.getFileNames().stream() - .collect(Collectors.toSet()); + Set fileNames = ImmutableSet.copyOf(partitionUpdate.getFileNames()); ImmutableList.Builder missingFileNamesBuilder = ImmutableList.builder(); for (int i = 0; i < bucketCount; i++) { String fileName = HiveWriterFactory.computeBucketedFileName(filePrefix, i) + fileExtension; @@ -733,7 +756,7 @@ private List computeFileNamesForMissingBuckets(HiveStorageFormat storage private void createEmptyFile(Path path, Table table, Optional partition, List fileNames) { - JobConf conf = new JobConf(hdfsEnvironment.getConfiguration(path)); + JobConf conf = toJobConf(hdfsEnvironment.getConfiguration(path)); Properties schema; StorageFormat format; @@ -911,6 +934,8 @@ public void createView(ConnectorSession session, SchemaTableName viewName, Strin Map properties = ImmutableMap.builder() .put(TABLE_COMMENT, "Presto View") .put(PRESTO_VIEW_FLAG, "true") + .put(PRESTO_VERSION_NAME, prestoVersion) + .put(PRESTO_QUERY_ID_NAME, session.getQueryId()) .build(); Column dummyColumn = new Column("dummy", HIVE_STRING, Optional.empty()); @@ -1179,21 +1204,25 @@ private static Domain buildColumnDomain(ColumnHandle column, List @Override public Optional getInsertLayout(ConnectorSession session, ConnectorTableHandle tableHandle) { - HivePartitionResult hivePartitionResult = partitionManager.getPartitions(metastore, tableHandle, Constraint.alwaysTrue()); - if (!hivePartitionResult.getBucketHandle().isPresent()) { + HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle; + SchemaTableName tableName = hiveTableHandle.getSchemaTableName(); + Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()) + .orElseThrow(() -> new TableNotFoundException(tableName)); + + Optional hiveBucketHandle = getHiveBucketHandle(connectorId, table); + if (!hiveBucketHandle.isPresent()) { return Optional.empty(); } if (!bucketWritingEnabled) { throw new PrestoException(NOT_SUPPORTED, "Writing to bucketed Hive table has been temporarily disabled"); } - HiveBucketHandle hiveBucketHandle = hivePartitionResult.getBucketHandle().get(); HivePartitioningHandle partitioningHandle = new HivePartitioningHandle( connectorId, - hiveBucketHandle.getBucketCount(), - hiveBucketHandle.getColumns().stream() + hiveBucketHandle.get().getBucketCount(), + hiveBucketHandle.get().getColumns().stream() .map(HiveColumnHandle::getHiveType) .collect(Collectors.toList())); - List partitionColumns = hiveBucketHandle.getColumns().stream() + List partitionColumns = hiveBucketHandle.get().getColumns().stream() .map(HiveColumnHandle::getName) .collect(Collectors.toList()); return Optional.of(new ConnectorNewTableLayout(partitioningHandle, partitionColumns)); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java index 8251b15709d64..706986fe71188 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.hive.metastore.CachingHiveMetastore; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore; +import com.facebook.presto.hive.statistics.MetastoreHiveStatisticsProvider; import com.facebook.presto.spi.type.TypeManager; import io.airlift.concurrent.BoundedExecutor; import io.airlift.json.JsonCodec; @@ -145,13 +146,15 @@ public HiveMetadataFactory( public HiveMetadata create() { + SemiTransactionalHiveMetastore metastore = new SemiTransactionalHiveMetastore( + hdfsEnvironment, + CachingHiveMetastore.memoizeMetastore(this.metastore, perTransactionCacheMaximumSize), // per-transaction cache + renameExecution, + skipDeletionForAlter); + return new HiveMetadata( connectorId, - new SemiTransactionalHiveMetastore( - hdfsEnvironment, - CachingHiveMetastore.memoizeMetastore(metastore, perTransactionCacheMaximumSize), // per-transaction cache - renameExecution, - skipDeletionForAlter), + metastore, hdfsEnvironment, partitionManager, timeZone, @@ -165,6 +168,7 @@ public HiveMetadata create() tableParameterCodec, partitionUpdateCodec, typeTranslator, - prestoVersion); + prestoVersion, + new MetastoreHiveStatisticsProvider(typeManager, metastore)); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java index c1a4de0d3beb5..2fd7916cc3338 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java @@ -33,7 +33,6 @@ import com.google.common.collect.Maps; import io.airlift.slice.Slice; import org.apache.hadoop.hive.common.FileUtils; -import org.apache.hadoop.hive.metastore.ProtectMode; import org.joda.time.DateTimeZone; import javax.inject.Inject; @@ -49,14 +48,14 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_EXCEEDED_PARTITION_LIMIT; import static com.facebook.presto.hive.HiveUtil.getPartitionKeyColumnHandles; import static com.facebook.presto.hive.HiveUtil.parsePartitionValue; +import static com.facebook.presto.hive.metastore.MetastoreUtil.getProtectMode; +import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyOnline; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Predicates.not; -import static com.google.common.base.Strings.isNullOrEmpty; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; -import static org.apache.hadoop.hive.metastore.ProtectMode.getProtectModeFromString; public class HivePartitionManager { @@ -160,20 +159,20 @@ public HivePartitionResult getPartitions(SemiTransactionalHiveMetastore metastor private static TupleDomain toCompactTupleDomain(TupleDomain effectivePredicate, int threshold) { - checkArgument(effectivePredicate.getDomains().isPresent()); - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry entry : effectivePredicate.getDomains().get().entrySet()) { - HiveColumnHandle hiveColumnHandle = (HiveColumnHandle) entry.getKey(); - - ValueSet values = entry.getValue().getValues(); - ValueSet compactValueSet = values.getValuesProcessor().>transform( - ranges -> ranges.getRangeCount() > threshold ? Optional.of(ValueSet.ofRanges(ranges.getSpan())) : Optional.empty(), - discreteValues -> discreteValues.getValues().size() > threshold ? Optional.of(ValueSet.all(values.getType())) : Optional.empty(), - allOrNone -> Optional.empty()) - .orElse(values); - builder.put(hiveColumnHandle, Domain.create(compactValueSet, entry.getValue().isNullAllowed())); - } + effectivePredicate.getDomains().ifPresent(domains -> { + for (Map.Entry entry : domains.entrySet()) { + HiveColumnHandle hiveColumnHandle = (HiveColumnHandle) entry.getKey(); + + ValueSet values = entry.getValue().getValues(); + ValueSet compactValueSet = values.getValuesProcessor().>transform( + ranges -> ranges.getRangeCount() > threshold ? Optional.of(ValueSet.ofRanges(ranges.getSpan())) : Optional.empty(), + discreteValues -> discreteValues.getValues().size() > threshold ? Optional.of(ValueSet.all(values.getType())) : Optional.empty(), + allOrNone -> Optional.empty()) + .orElse(values); + builder.put(hiveColumnHandle, Domain.create(compactValueSet, entry.getValue().isNullAllowed())); + } + }); return TupleDomain.withColumnDomains(builder.build()); } @@ -209,17 +208,7 @@ private Table getTable(SemiTransactionalHiveMetastore metastore, SchemaTableName throw new TableNotFoundException(tableName); } Table table = target.get(); - - String protectMode = table.getParameters().get(ProtectMode.PARAMETER_NAME); - if (protectMode != null && getProtectModeFromString(protectMode).offline) { - throw new TableOfflineException(tableName, false, null); - } - - String prestoOffline = table.getParameters().get(PRESTO_OFFLINE); - if (!isNullOrEmpty(prestoOffline)) { - throw new TableOfflineException(tableName, true, prestoOffline); - } - + verifyOnline(tableName, Optional.empty(), getProtectMode(table), table.getParameters()); return table; } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveS3Config.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveS3Config.java index 3d36aa3f3fccc..d7687a0be476b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveS3Config.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveS3Config.java @@ -16,6 +16,7 @@ import com.google.common.base.StandardSystemProperty; import io.airlift.configuration.Config; import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.ConfigSecuritySensitive; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.airlift.units.MinDataSize; @@ -73,6 +74,7 @@ public String getS3AwsSecretKey() } @Config("hive.s3.aws-secret-key") + @ConfigSecuritySensitive public HiveS3Config setS3AwsSecretKey(String s3AwsSecretKey) { this.s3AwsSecretKey = s3AwsSecretKey; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java index 06e12f423671f..98026bade193d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java @@ -33,13 +33,14 @@ public final class HiveSessionProperties private static final String ORC_MAX_MERGE_DISTANCE = "orc_max_merge_distance"; private static final String ORC_MAX_BUFFER_SIZE = "orc_max_buffer_size"; private static final String ORC_STREAM_BUFFER_SIZE = "orc_stream_buffer_size"; + private static final String ORC_MAX_READ_BLOCK_SIZE = "orc_max_read_block_size"; private static final String PARQUET_PREDICATE_PUSHDOWN_ENABLED = "parquet_predicate_pushdown_enabled"; private static final String PARQUET_OPTIMIZED_READER_ENABLED = "parquet_optimized_reader_enabled"; private static final String MAX_SPLIT_SIZE = "max_split_size"; private static final String MAX_INITIAL_SPLIT_SIZE = "max_initial_split_size"; - private static final String RCFILE_OPTIMIZED_READER_ENABLED = "rcfile_optimized_reader_enabled"; - private static final String RCFILE_OPTIMIZED_WRITER_ENABLED = "rcfile_optimized_writer_enabled"; + public static final String RCFILE_OPTIMIZED_WRITER_ENABLED = "rcfile_optimized_writer_enabled"; private static final String RCFILE_OPTIMIZED_WRITER_VALIDATE = "rcfile_optimized_writer_validate"; + private static final String STATISTICS_ENABLED = "statistics_enabled"; private final List> sessionProperties; @@ -77,6 +78,11 @@ public HiveSessionProperties(HiveClientConfig config) "ORC: Size of buffer for streaming reads", config.getOrcStreamBufferSize(), false), + dataSizeSessionProperty( + ORC_MAX_READ_BLOCK_SIZE, + "ORC: Maximum size of a block to read", + config.getOrcMaxReadBlockSize(), + false), booleanSessionProperty( PARQUET_OPTIMIZED_READER_ENABLED, "Experimental: Parquet: Enable optimized reader", @@ -97,11 +103,6 @@ public HiveSessionProperties(HiveClientConfig config) "Max initial split size", config.getMaxInitialSplitSize(), true), - booleanSessionProperty( - RCFILE_OPTIMIZED_READER_ENABLED, - "Experimental: RCFile: Enable optimized reader", - config.isRcfileOptimizedReaderEnabled(), - false), booleanSessionProperty( RCFILE_OPTIMIZED_WRITER_ENABLED, "Experimental: RCFile: Enable optimized writer", @@ -110,6 +111,11 @@ public HiveSessionProperties(HiveClientConfig config) booleanSessionProperty( RCFILE_OPTIMIZED_WRITER_VALIDATE, "Experimental: RCFile: Validate writer files", + config.isRcfileWriterValidate(), + false), + booleanSessionProperty( + STATISTICS_ENABLED, + "Experimental: Expose table statistics", true, false)); } @@ -154,6 +160,11 @@ public static DataSize getOrcStreamBufferSize(ConnectorSession session) return session.getProperty(ORC_STREAM_BUFFER_SIZE, DataSize.class); } + public static DataSize getOrcMaxReadBlockSize(ConnectorSession session) + { + return session.getProperty(ORC_MAX_READ_BLOCK_SIZE, DataSize.class); + } + public static boolean isParquetPredicatePushdownEnabled(ConnectorSession session) { return session.getProperty(PARQUET_PREDICATE_PUSHDOWN_ENABLED, Boolean.class); @@ -169,11 +180,6 @@ public static DataSize getMaxInitialSplitSize(ConnectorSession session) return session.getProperty(MAX_INITIAL_SPLIT_SIZE, DataSize.class); } - public static boolean isRcfileOptimizedReaderEnabled(ConnectorSession session) - { - return session.getProperty(RCFILE_OPTIMIZED_READER_ENABLED, Boolean.class); - } - public static boolean isRcfileOptimizedWriterEnabled(ConnectorSession session) { return session.getProperty(RCFILE_OPTIMIZED_WRITER_ENABLED, Boolean.class); @@ -184,6 +190,11 @@ public static boolean isRcfileOptimizedWriterValidate(ConnectorSession session) return session.getProperty(RCFILE_OPTIMIZED_WRITER_VALIDATE, Boolean.class); } + public static boolean isStatisticsEnabled(ConnectorSession session) + { + return session.getProperty(STATISTICS_ENABLED, Boolean.class); + } + public static PropertyMetadata dataSizeSessionProperty(String name, String description, DataSize defaultValue, boolean hidden) { return new PropertyMetadata<>( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java index d6bf69f143c80..e01095ada28f6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java @@ -26,7 +26,6 @@ import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; -import com.google.common.base.Verify; import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -51,7 +50,6 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH; import static com.facebook.presto.hive.HivePartition.UNPARTITIONED_ID; -import static com.facebook.presto.hive.HiveUtil.checkCondition; import static com.facebook.presto.hive.metastore.MetastoreUtil.makePartName; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.SERVER_SHUTTING_DOWN; @@ -206,12 +204,11 @@ private Iterable getPartitionMetadata(SemiTransactionalHi ImmutableMap.Builder partitionBuilder = ImmutableMap.builder(); for (Map.Entry> entry : batch.entrySet()) { if (!entry.getValue().isPresent()) { - throw new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available"); + throw new PrestoException(HIVE_METASTORE_ERROR, "Partition no longer exists: " + entry.getKey()); } partitionBuilder.put(entry.getKey(), entry.getValue().get()); } Map partitions = partitionBuilder.build(); - Verify.verify(partitions.size() == partitionBatch.size()); if (partitionBatch.size() != partitions.size()) { throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Expected %s partitions but found %s", partitionBatch.size(), partitions.size())); } @@ -266,15 +263,25 @@ private Iterable getPartitionMetadata(SemiTransactionalHi } } - Optional partitionBucketProperty = partition.getStorage().getBucketProperty(); - checkCondition( - partitionBucketProperty.equals(bucketProperty), - HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH, - "Hive table (%s) bucketing property (%s) does not match partition (%s) bucketing property (%s)", - hivePartition.getTableName(), - bucketProperty, - hivePartition.getPartitionId(), - partitionBucketProperty); + if (bucketProperty.isPresent()) { + Optional partitionBucketProperty = partition.getStorage().getBucketProperty(); + if (!partitionBucketProperty.isPresent()) { + throw new PrestoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( + "Hive table (%s) is bucketed but partition (%s) is not bucketed", + hivePartition.getTableName(), + hivePartition.getPartitionId())); + } + if (!bucketProperty.equals(partitionBucketProperty)) { + throw new PrestoException(HIVE_PARTITION_SCHEMA_MISMATCH, format( + "Hive table (%s) bucketing (columns=%s, buckets=%s) does not match partition (%s) bucketing (columns=%s, buckets=%s)", + hivePartition.getTableName(), + bucketProperty.get().getBucketedBy(), + bucketProperty.get().getBucketCount(), + hivePartition.getPartitionId(), + partitionBucketProperty.get().getBucketedBy(), + partitionBucketProperty.get().getBucketCount())); + } + } results.add(new HivePartitionMetadata(hivePartition, Optional.of(partition), columnCoercions.build())); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java index 7feff178cee51..f6c7ab02b9177 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java @@ -28,8 +28,11 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.google.common.base.Joiner; +import com.google.common.base.Splitter; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import io.airlift.compress.lzo.LzoCodec; +import io.airlift.compress.lzo.LzopCodec; import io.airlift.slice.Slice; import io.airlift.slice.SliceUtf8; import io.airlift.slice.Slices; @@ -89,11 +92,12 @@ import static com.facebook.presto.hive.HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION; import static com.facebook.presto.hive.RetryDriver.retry; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.Chars.isCharType; -import static com.facebook.presto.spi.type.Chars.trimSpaces; +import static com.facebook.presto.spi.type.Chars.trimTrailingSpaces; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.DecimalType.createDecimalType; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; @@ -104,6 +108,7 @@ import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.filter; +import static com.google.common.collect.Lists.newArrayList; import static com.google.common.collect.Lists.transform; import static java.lang.Byte.parseByte; import static java.lang.Double.parseDouble; @@ -116,6 +121,7 @@ import static java.math.BigDecimal.ROUND_UNNECESSARY; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; import static org.apache.hadoop.hive.common.FileUtils.unescapePathName; import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.FILE_INPUT_FORMAT; import static org.apache.hadoop.hive.serde.serdeConstants.DECIMAL_TYPE_NAME; @@ -167,7 +173,7 @@ private HiveUtil() setReadColumns(configuration, readHiveColumnIndexes); InputFormat inputFormat = getInputFormat(configuration, schema, true); - JobConf jobConf = new JobConf(configuration); + JobConf jobConf = toJobConf(configuration); FileSplit fileSplit = new FileSplit(path, start, length, (String[]) null); // propagate serialization configuration to getRecordReader @@ -175,6 +181,16 @@ private HiveUtil() .filter(name -> name.startsWith("serialization.")) .forEach(name -> jobConf.set(name, schema.getProperty(name))); + // add Airlift LZO and LZOP to head of codecs list so as to not override existing entries + List codecs = newArrayList(Splitter.on(",").trimResults().omitEmptyStrings().split(jobConf.get("io.compression.codecs", ""))); + if (!codecs.contains(LzoCodec.class.getName())) { + codecs.add(0, LzoCodec.class.getName()); + } + if (!codecs.contains(LzopCodec.class.getName())) { + codecs.add(0, LzopCodec.class.getName()); + } + jobConf.set("io.compression.codecs", codecs.stream().collect(joining(","))); + try { return retry() .stopOnIllegalExceptions() @@ -201,7 +217,7 @@ public static void setReadColumns(Configuration configuration, List rea { String inputFormatName = getInputFormatName(schema); try { - JobConf jobConf = new JobConf(configuration); + JobConf jobConf = toJobConf(configuration); Class> inputFormatClass = getInputFormatClass(jobConf, inputFormatName); if (symlinkTarget && (inputFormatClass == SymlinkTextInputFormat.class)) { @@ -688,7 +704,7 @@ public static Slice varcharPartitionKey(String value, String name, Type columnTy public static Slice charPartitionKey(String value, String name, Type columnType) { - Slice partitionKey = trimSpaces(Slices.utf8Slice(value)); + Slice partitionKey = trimTrailingSpaces(Slices.utf8Slice(value)); CharType charType = (CharType) columnType; if (SliceUtf8.countCodePoints(partitionKey) > charType.getLength()) { throw new PrestoException(HIVE_INVALID_PARTITION_VALUE, format("Invalid partition value '%s' for %s partition key: %s", value, columnType.toString(), name)); @@ -753,11 +769,6 @@ public static List getPartitionKeyColumnHandles(String connect return columns.build(); } - public static Slice base64Decode(byte[] bytes) - { - return Slices.wrappedBuffer(Base64.getDecoder().decode(bytes)); - } - public static void checkCondition(boolean condition, ErrorCodeSupplier errorCode, String formatString, Object... args) { if (!condition) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java index ee7dd77ed1332..4b5dc72867c75 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java @@ -96,15 +96,14 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_DATABASE_LOCATION_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_FILESYSTEM_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_DATA_ERROR; -import static com.facebook.presto.hive.HiveSplitManager.PRESTO_OFFLINE; import static com.facebook.presto.hive.HiveUtil.checkCondition; import static com.facebook.presto.hive.HiveUtil.isArrayType; import static com.facebook.presto.hive.HiveUtil.isMapType; import static com.facebook.presto.hive.HiveUtil.isRowType; import static com.facebook.presto.hive.metastore.MetastoreUtil.getProtectMode; +import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyOnline; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.Chars.isCharType; -import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.base.Strings.padEnd; import static java.lang.Float.intBitsToFloat; import static java.lang.Math.toIntExact; @@ -367,20 +366,7 @@ private static void checkWritable( } // verify online - if (protectMode.offline) { - if (partitionName.isPresent()) { - throw new PartitionOfflineException(tableName, partitionName.get(), false, null); - } - throw new TableOfflineException(tableName, false, null); - } - - String prestoOffline = parameters.get(PRESTO_OFFLINE); - if (!isNullOrEmpty(prestoOffline)) { - if (partitionName.isPresent()) { - throw new PartitionOfflineException(tableName, partitionName.get(), true, prestoOffline); - } - throw new TableOfflineException(tableName, true, prestoOffline); - } + verifyOnline(tableName, partitionName, protectMode, parameters); // verify not read only if (protectMode.readOnly) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java index 910ea7333e940..8e60fc63a6ea0 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java @@ -63,6 +63,7 @@ import static com.facebook.presto.hive.HiveWriteUtils.getField; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; import static com.facebook.presto.hive.metastore.StorageFormat.fromHiveStorageFormat; +import static com.facebook.presto.hive.util.ConfigurationUtils.toJobConf; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -201,7 +202,7 @@ public HiveWriterFactory( entry -> session.getProperty(entry.getName(), entry.getJavaType()).toString())); Configuration conf = hdfsEnvironment.getConfiguration(writePath); - this.conf = new JobConf(conf); + this.conf = toJobConf(conf); // make sure the FileSystem is created with the correct Configuration object try { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/NumberParser.java b/presto-hive/src/main/java/com/facebook/presto/hive/NumberParser.java deleted file mode 100644 index 9c359bf8f4ec6..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/NumberParser.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.PrestoException; - -import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA; - -public final class NumberParser -{ - private NumberParser() {} - - public static long parseLong(byte[] bytes, int start, int length) - { - int limit = start + length; - - int sign = bytes[start] == '-' ? -1 : 1; - - if (sign == -1 || bytes[start] == '+') { - start++; - } - - long value = bytes[start] - ((int) '0'); - start++; - while (start < limit) { - value = value * 10 + (bytes[start] - ((int) '0')); - start++; - } - - return value * sign; - } - - public static float parseFloat(byte[] bytes, int start, int length) - { - String string = new String(bytes, 0, start, length); - try { - return Float.parseFloat(string); - } - catch (NumberFormatException e) { - throw new PrestoException(HIVE_BAD_DATA, e); - } - } - - public static double parseDouble(byte[] bytes, int start, int length) - { - String string = new String(bytes, 0, start, length); - try { - return Double.parseDouble(string); - } - catch (NumberFormatException e) { - throw new PrestoException(HIVE_BAD_DATA, e); - } - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/PartitionStatistics.java b/presto-hive/src/main/java/com/facebook/presto/hive/PartitionStatistics.java new file mode 100644 index 0000000000000..bfb7ec9ff698c --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/PartitionStatistics.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.hive; + +import com.facebook.presto.hive.metastore.HiveColumnStatistics; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; +import java.util.OptionalLong; + +import static java.util.Objects.requireNonNull; + +public class PartitionStatistics +{ + public static final PartitionStatistics EMPTY_STATISTICS = new PartitionStatistics( + false, + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.empty(), + ImmutableMap.of()); + + private final boolean columnStatsAcurate; + private final OptionalLong fileCount; + private final OptionalLong rowCount; + private final OptionalLong rawDataSize; + private final OptionalLong totalSize; + private final Map columnStatistics; + + public PartitionStatistics( + boolean columnStatsAcurate, + OptionalLong fileCount, + OptionalLong rowCount, + OptionalLong rawDataSize, + OptionalLong totalSize, + Map columnStatistics) + { + this.columnStatsAcurate = columnStatsAcurate; + this.fileCount = fileCount; + this.rowCount = rowCount; + this.rawDataSize = rawDataSize; + this.totalSize = totalSize; + this.columnStatistics = ImmutableMap.copyOf(requireNonNull(columnStatistics, "columnStatistics can not be null")); + } + + public boolean isColumnStatsAcurate() + { + return columnStatsAcurate; + } + + public OptionalLong getFileCount() + { + return fileCount; + } + + public OptionalLong getRowCount() + { + return rowCount; + } + + public OptionalLong getRawDataSize() + { + return rawDataSize; + } + + public OptionalLong getTotalSize() + { + return totalSize; + } + + public Map getColumnStatistics() + { + return columnStatistics; + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoHadoopConfiguration.java b/presto-hive/src/main/java/com/facebook/presto/hive/PrestoHadoopConfiguration.java deleted file mode 100644 index f65d619b34c7b..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoHadoopConfiguration.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.google.common.collect.ClassToInstanceMap; -import com.google.common.collect.ImmutableClassToInstanceMap; -import org.apache.hadoop.conf.Configuration; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public final class PrestoHadoopConfiguration - extends Configuration -{ - private final ClassToInstanceMap services; - - public PrestoHadoopConfiguration(ClassToInstanceMap services) - { - super(false); - this.services = ImmutableClassToInstanceMap.copyOf(requireNonNull(services, "services is null")); - } - - public T getService(Class type) - { - T service = services.getInstance(type); - checkArgument(service != null, "service not found: %s", type.getName()); - return service; - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java b/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java index a36e2f5afa99d..dd252ef757466 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/PrestoS3FileSystem.java @@ -83,8 +83,10 @@ import java.util.Date; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Optional; +import static com.amazonaws.services.s3.Headers.SERVER_SIDE_ENCRYPTION; import static com.amazonaws.services.s3.Headers.UNENCRYPTED_CONTENT_LENGTH; import static com.facebook.presto.hive.RetryDriver.retry; import static com.google.common.base.Preconditions.checkArgument; @@ -315,7 +317,7 @@ public FileStatus getFileStatus(Path path) } return new FileStatus( - getObjectSize(metadata), + getObjectSize(path, metadata), false, 1, BLOCK_SIZE.toBytes(), @@ -323,9 +325,14 @@ public FileStatus getFileStatus(Path path) qualifiedPath(path)); } - private static long getObjectSize(ObjectMetadata metadata) + private static long getObjectSize(Path path, ObjectMetadata metadata) + throws IOException { - String length = metadata.getUserMetadata().get(UNENCRYPTED_CONTENT_LENGTH); + Map userMetadata = metadata.getUserMetadata(); + String length = userMetadata.get(UNENCRYPTED_CONTENT_LENGTH); + if (userMetadata.containsKey(SERVER_SIDE_ENCRYPTION) && length == null) { + throw new IOException(format("%s header is not set on an encrypted object: %s", UNENCRYPTED_CONTENT_LENGTH, path)); + } return (length != null) ? Long.parseLong(length) : metadata.getContentLength(); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java index e56f17508fd31..8bbc267320401 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java @@ -57,27 +57,31 @@ public class RcFileFileWriterFactory private final HdfsEnvironment hdfsEnvironment; private final TypeManager typeManager; private final NodeVersion nodeVersion; + private final FileFormatDataSourceStats stats; @Inject public RcFileFileWriterFactory( HdfsEnvironment hdfsEnvironment, TypeManager typeManager, NodeVersion nodeVersion, - HiveClientConfig hiveClientConfig) + HiveClientConfig hiveClientConfig, + FileFormatDataSourceStats stats) { - this(hdfsEnvironment, typeManager, nodeVersion, requireNonNull(hiveClientConfig, "hiveClientConfig is null").getDateTimeZone()); + this(hdfsEnvironment, typeManager, nodeVersion, requireNonNull(hiveClientConfig, "hiveClientConfig is null").getDateTimeZone(), stats); } public RcFileFileWriterFactory( HdfsEnvironment hdfsEnvironment, TypeManager typeManager, NodeVersion nodeVersion, - DateTimeZone hiveStorageTimeZone) + DateTimeZone hiveStorageTimeZone, + FileFormatDataSourceStats stats) { this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); this.hiveStorageTimeZone = requireNonNull(hiveStorageTimeZone, "hiveStorageTimeZone is null"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -132,7 +136,8 @@ else if (ColumnarSerDe.class.getName().equals(storageFormat.getSerDe())) { return new HdfsRcFileDataSource( path.toString(), fileSystem.open(path), - fileSystem.getFileStatus(path).getLen()); + fileSystem.getFileStatus(path).getLen(), + stats); } catch (IOException e) { throw Throwables.propagate(e); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java b/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java index 3a37c831cbe46..4facd9771f38a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/StaticHiveCluster.java @@ -71,7 +71,7 @@ public HiveMetastoreClient createMetastoreClient() TTransportException lastException = null; for (HostAndPort metastore : metastores) { try { - return clientFactory.create(metastore.getHostText(), metastore.getPort()); + return clientFactory.create(metastore.getHost(), metastore.getPort()); } catch (TTransportException e) { lastException = e; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ThriftHiveMetastoreClient.java b/presto-hive/src/main/java/com/facebook/presto/hive/ThriftHiveMetastoreClient.java index b037f40b75a7a..4208acc3c8627 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ThriftHiveMetastoreClient.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ThriftHiveMetastoreClient.java @@ -14,15 +14,18 @@ package com.facebook.presto.hive; import com.facebook.presto.hive.metastore.HiveMetastoreClient; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; import org.apache.hadoop.hive.metastore.api.HiveObjectRef; import org.apache.hadoop.hive.metastore.api.Partition; +import org.apache.hadoop.hive.metastore.api.PartitionsStatsRequest; import org.apache.hadoop.hive.metastore.api.PrincipalPrivilegeSet; import org.apache.hadoop.hive.metastore.api.PrincipalType; import org.apache.hadoop.hive.metastore.api.PrivilegeBag; import org.apache.hadoop.hive.metastore.api.Role; import org.apache.hadoop.hive.metastore.api.Table; +import org.apache.hadoop.hive.metastore.api.TableStatsRequest; import org.apache.hadoop.hive.metastore.api.ThriftHiveMetastore; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; @@ -30,6 +33,7 @@ import org.apache.thrift.transport.TTransport; import java.util.List; +import java.util.Map; import static java.util.Objects.requireNonNull; @@ -134,6 +138,22 @@ public Table getTable(String databaseName, String tableName) return client.get_table(databaseName, tableName); } + @Override + public List getTableColumnStatistics(String databaseName, String tableName, List columnNames) + throws TException + { + TableStatsRequest tableStatsRequest = new TableStatsRequest(databaseName, tableName, columnNames); + return client.get_table_statistics_req(tableStatsRequest).getTableStats(); + } + + @Override + public Map> getPartitionColumnStatistics(String databaseName, String tableName, List columnNames, List partitionValues) + throws TException + { + PartitionsStatsRequest partitionsStatsRequest = new PartitionsStatsRequest(databaseName, tableName, columnNames, partitionValues); + return client.get_partitions_statistics_req(partitionsStatsRequest).getPartStats(); + } + @Override public List getPartitionNames(String databaseName, String tableName) throws TException diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java b/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java index 10288aa8629eb..c1b3591da07ca 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java @@ -14,11 +14,13 @@ package com.facebook.presto.hive; import io.airlift.event.client.EventField; +import io.airlift.event.client.EventField.EventFieldMapping; import io.airlift.event.client.EventType; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; +import java.time.Instant; import java.util.Map; import static java.util.Objects.requireNonNull; @@ -35,12 +37,13 @@ public class WriteCompletedEvent private final String storageFormat; private final String writerImplementation; private final String prestoVersion; - private final String serverAddress; + private final String host; private final String principal; private final String environment; private final Map sessionProperties; private final Long bytes; private final long rows; + private final Instant timestamp = Instant.now(); public WriteCompletedEvent( String queryId, @@ -66,7 +69,7 @@ public WriteCompletedEvent( this.storageFormat = requireNonNull(storageFormat, "storageFormat is null"); this.writerImplementation = requireNonNull(writerImplementation, "writerImplementation is null"); this.prestoVersion = requireNonNull(prestoVersion, "prestoVersion is null"); - this.serverAddress = requireNonNull(serverAddress, "serverAddress is null"); + this.host = requireNonNull(serverAddress, "serverAddress is null"); this.principal = principal; this.environment = requireNonNull(environment, "environment is null"); this.sessionProperties = requireNonNull(sessionProperties, "sessionProperties is null"); @@ -123,10 +126,10 @@ public String getPrestoVersion() return prestoVersion; } - @EventField - public String getServerAddress() + @EventField(fieldMapping = EventFieldMapping.HOST) + public String getHost() { - return serverAddress; + return host; } @Nullable @@ -160,4 +163,10 @@ public long getRows() { return rows; } + + @EventField(fieldMapping = EventFieldMapping.TIMESTAMP) + public Instant getTimestamp() + { + return timestamp; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java index 285bf293cfa71..0059c118b6d54 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java @@ -15,7 +15,6 @@ import com.facebook.presto.hive.ForHdfs; import com.facebook.presto.hive.ForHiveMetastore; -import com.facebook.presto.hive.HiveClientConfig; import com.google.inject.Binder; import com.google.inject.Key; import com.google.inject.Module; @@ -25,6 +24,7 @@ import javax.inject.Inject; import static com.google.inject.Scopes.SINGLETON; +import static io.airlift.configuration.ConfigBinder.configBinder; public final class AuthenticationModules { @@ -48,15 +48,16 @@ public void configure(Binder binder) binder.bind(HiveMetastoreAuthentication.class) .to(KerberosHiveMetastoreAuthentication.class) .in(SINGLETON); + configBinder(binder).bindConfig(MetastoreKerberosConfig.class); } @Provides @Singleton @ForHiveMetastore - HadoopAuthentication createHadoopAuthentication(HiveClientConfig hiveClientConfig) + HadoopAuthentication createHadoopAuthentication(MetastoreKerberosConfig config) { - String principal = hiveClientConfig.getHiveMetastoreClientPrincipal(); - String keytabLocation = hiveClientConfig.getHiveMetastoreClientKeytab(); + String principal = config.getHiveMetastoreClientPrincipal(); + String keytabLocation = config.getHiveMetastoreClientKeytab(); return createCachingKerberosHadoopAuthentication(principal, keytabLocation); } }; @@ -91,16 +92,17 @@ public void configure(Binder binder) binder.bind(HdfsAuthentication.class) .to(DirectHdfsAuthentication.class) .in(SINGLETON); + configBinder(binder).bindConfig(HdfsKerberosConfig.class); } @Inject @Provides @Singleton @ForHdfs - HadoopAuthentication createHadoopAuthentication(HiveClientConfig hiveClientConfig) + HadoopAuthentication createHadoopAuthentication(HdfsKerberosConfig config) { - String principal = hiveClientConfig.getHdfsPrestoPrincipal(); - String keytabLocation = hiveClientConfig.getHdfsPrestoKeytab(); + String principal = config.getHdfsPrestoPrincipal(); + String keytabLocation = config.getHdfsPrestoKeytab(); return createCachingKerberosHadoopAuthentication(principal, keytabLocation); } }; @@ -116,16 +118,17 @@ public void configure(Binder binder) binder.bind(HdfsAuthentication.class) .to(ImpersonatingHdfsAuthentication.class) .in(SINGLETON); + configBinder(binder).bindConfig(HdfsKerberosConfig.class); } @Inject @Provides @Singleton @ForHdfs - HadoopAuthentication createHadoopAuthentication(HiveClientConfig hiveClientConfig) + HadoopAuthentication createHadoopAuthentication(HdfsKerberosConfig config) { - String principal = hiveClientConfig.getHdfsPrestoPrincipal(); - String keytabLocation = hiveClientConfig.getHdfsPrestoKeytab(); + String principal = config.getHdfsPrestoPrincipal(); + String keytabLocation = config.getHdfsPrestoKeytab(); return createCachingKerberosHadoopAuthentication(principal, keytabLocation); } }; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java new file mode 100644 index 0000000000000..23938da6a6bc4 --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.authentication; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +import javax.validation.constraints.NotNull; + +public class HdfsKerberosConfig +{ + private String hdfsPrestoPrincipal; + private String hdfsPrestoKeytab; + + @NotNull + public String getHdfsPrestoPrincipal() + { + return hdfsPrestoPrincipal; + } + + @Config("hive.hdfs.presto.principal") + @ConfigDescription("Presto principal used to access HDFS") + public HdfsKerberosConfig setHdfsPrestoPrincipal(String hdfsPrestoPrincipal) + { + this.hdfsPrestoPrincipal = hdfsPrestoPrincipal; + return this; + } + + @NotNull + public String getHdfsPrestoKeytab() + { + return hdfsPrestoKeytab; + } + + @Config("hive.hdfs.presto.keytab") + @ConfigDescription("Presto keytab used to access HDFS") + public HdfsKerberosConfig setHdfsPrestoKeytab(String hdfsPrestoKeytab) + { + this.hdfsPrestoKeytab = hdfsPrestoKeytab; + return this; + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java index 41d4d79326fde..4b15870f57825 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java @@ -14,7 +14,6 @@ package com.facebook.presto.hive.authentication; import com.facebook.presto.hive.ForHiveMetastore; -import com.facebook.presto.hive.HiveClientConfig; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.apache.hadoop.hive.thrift.client.TUGIAssumingTransport; @@ -41,9 +40,9 @@ public class KerberosHiveMetastoreAuthentication private final HadoopAuthentication authentication; @Inject - public KerberosHiveMetastoreAuthentication(HiveClientConfig hiveClientConfig, @ForHiveMetastore HadoopAuthentication authentication) + public KerberosHiveMetastoreAuthentication(MetastoreKerberosConfig config, @ForHiveMetastore HadoopAuthentication authentication) { - this(hiveClientConfig.getHiveMetastoreServicePrincipal(), authentication); + this(config.getHiveMetastoreServicePrincipal(), authentication); } public KerberosHiveMetastoreAuthentication(String hiveMetastoreServicePrincipal, HadoopAuthentication authentication) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java new file mode 100644 index 0000000000000..33b64a36966d0 --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.authentication; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +import javax.validation.constraints.NotNull; + +public class MetastoreKerberosConfig +{ + private String hiveMetastoreServicePrincipal; + private String hiveMetastoreClientPrincipal; + private String hiveMetastoreClientKeytab; + + @NotNull + public String getHiveMetastoreServicePrincipal() + { + return hiveMetastoreServicePrincipal; + } + + @Config("hive.metastore.service.principal") + @ConfigDescription("Hive Metastore service principal") + public MetastoreKerberosConfig setHiveMetastoreServicePrincipal(String hiveMetastoreServicePrincipal) + { + this.hiveMetastoreServicePrincipal = hiveMetastoreServicePrincipal; + return this; + } + + @NotNull + public String getHiveMetastoreClientPrincipal() + { + return hiveMetastoreClientPrincipal; + } + + @Config("hive.metastore.client.principal") + @ConfigDescription("Hive Metastore client principal") + public MetastoreKerberosConfig setHiveMetastoreClientPrincipal(String hiveMetastoreClientPrincipal) + { + this.hiveMetastoreClientPrincipal = hiveMetastoreClientPrincipal; + return this; + } + + @NotNull + public String getHiveMetastoreClientKeytab() + { + return hiveMetastoreClientKeytab; + } + + @Config("hive.metastore.client.keytab") + @ConfigDescription("Hive Metastore client keytab location") + public MetastoreKerberosConfig setHiveMetastoreClientKeytab(String hiveMetastoreClientKeytab) + { + this.hiveMetastoreClientKeytab = hiveMetastoreClientKeytab; + return this; + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java index 7caed88eabe68..c43c2f0b85bb0 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/BridgingHiveMetastore.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; import com.google.common.collect.ImmutableMap; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.PrivilegeGrantInfo; @@ -73,6 +74,33 @@ public Optional
getTable(String databaseName, String tableName) return delegate.getTable(databaseName, tableName).map(MetastoreUtil::fromMetastoreApiTable); } + @Override + public Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + return delegate.getTableColumnStatistics(databaseName, tableName, columnNames).map(this::groupStatisticsByColumn); + } + + @Override + public Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames) + { + return delegate.getPartitionColumnStatistics(databaseName, tableName, partitionNames, columnNames).map( + statistics -> ImmutableMap.copyOf( + statistics.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> groupStatisticsByColumn(entry.getValue()) + )))); + } + + private Map groupStatisticsByColumn(Set statistics) + { + return ImmutableMap.copyOf( + statistics.stream() + .collect(Collectors.toMap( + ColumnStatisticsObj::getColName, + MetastoreUtil::fromMetastoreApiColumnStatistics))); + } + @Override public Optional> getAllTables(String databaseName) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java index ed89916130d6f..f72889f35b8c8 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/CachingHiveMetastore.java @@ -41,15 +41,19 @@ import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; import static com.facebook.presto.hive.HiveUtil.toPartitionValues; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.cache.CacheLoader.asyncReloading; import static com.google.common.collect.Iterables.transform; +import static com.google.common.collect.Streams.stream; import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toMap; /** * Hive Metastore Cache @@ -63,6 +67,8 @@ public class CachingHiveMetastore private final LoadingCache> databaseNamesCache; private final LoadingCache> tableCache; private final LoadingCache>> tableNamesCache; + private final LoadingCache> tableColumnStatisticsCache; + private final LoadingCache> partitionColumnStatisticsCache; private final LoadingCache>> viewNamesCache; private final LoadingCache> partitionCache; private final LoadingCache>> partitionFilterCache; @@ -139,6 +145,42 @@ public Optional> load(String databaseName) } }, executor)); + tableColumnStatisticsCache = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize) + .build(asyncReloading(new CacheLoader>() + { + @Override + public Optional load(TableColumnStatisticsCacheKey key) + throws Exception + { + return loadAll(ImmutableList.of(key)).get(key); + } + + @Override + public Map> loadAll(Iterable keys) + throws Exception + { + return loadColumnStatistics(keys); + } + }, executor)); + + partitionColumnStatisticsCache = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize) + .build(asyncReloading(new CacheLoader>() + { + @Override + public Optional load(PartitionColumnStatisticsCacheKey key) + throws Exception + { + return loadAll(ImmutableList.of(key)).get(key); + } + + @Override + public Map> loadAll(Iterable keys) + throws Exception + { + return loadPartitionColumnStatistics(keys); + } + }, executor)); + tableCache = newCacheBuilder(expiresAfterWriteMillis, refreshMills, maximumSize) .build(asyncReloading(new CacheLoader>() { @@ -294,6 +336,103 @@ private Optional
loadTable(HiveTableName hiveTableName) return delegate.getTable(hiveTableName.getDatabaseName(), hiveTableName.getTableName()); } + @Override + public Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + Map> cacheValues = + getAll(tableColumnStatisticsCache, columnNames.stream() + .map(columnName -> new TableColumnStatisticsCacheKey(databaseName, tableName, columnName)) + .collect(toList())); + + return Optional.of( + ImmutableMap.copyOf( + cacheValues.entrySet().stream() + .filter(entry -> entry.getValue().isPresent()) + .collect(toMap( + entry -> entry.getKey().getColumnName(), + entry -> entry.getValue().get())))); + } + + private Map> loadColumnStatistics(Iterable keys) + { + if (Iterables.isEmpty(keys)) { + return ImmutableMap.of(); + } + + HiveTableName hiveTableName = stream(keys).findFirst().get().getHiveTableName(); + checkArgument(stream(keys).allMatch(key -> key.getHiveTableName().equals(hiveTableName)), "all keys must relate to same hive table"); + + Set columnNames = stream(keys).map(TableColumnStatisticsCacheKey::getColumnName).collect(Collectors.toSet()); + + Optional> columnStatistics = delegate.getTableColumnStatistics(hiveTableName.getDatabaseName(), hiveTableName.getTableName(), columnNames); + + ImmutableMap.Builder> resultMap = ImmutableMap.builder(); + for (TableColumnStatisticsCacheKey key : keys) { + if (!columnStatistics.isPresent() || !columnStatistics.get().containsKey(key.getColumnName())) { + resultMap.put(key, Optional.empty()); + } + else { + resultMap.put(key, Optional.of(columnStatistics.get().get(key.getColumnName()))); + } + } + return resultMap.build(); + } + + @Override + public Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames) + { + List cacheKeys = partitionNames.stream() + .flatMap( + partitionName -> columnNames.stream().map( + columnName -> new PartitionColumnStatisticsCacheKey(databaseName, tableName, partitionName, columnName))) + .collect(toList()); + Map> cacheValues = getAll(partitionColumnStatisticsCache, cacheKeys); + + ImmutableMap.Builder> partitionsMap = ImmutableMap.builder(); + for (String partitionName : partitionNames) { + ImmutableMap.Builder columnsMap = ImmutableMap.builder(); + for (String columnName : columnNames) { + Optional cacheValue = cacheValues.get(new PartitionColumnStatisticsCacheKey(databaseName, tableName, partitionName, columnName)); + if (cacheValue.isPresent()) { + columnsMap.put(columnName, cacheValue.get()); + } + } + partitionsMap.put(partitionName, columnsMap.build()); + } + return Optional.of(partitionsMap.build()); + } + + private Map> loadPartitionColumnStatistics(Iterable keys) + { + if (Iterables.isEmpty(keys)) { + return ImmutableMap.of(); + } + PartitionColumnStatisticsCacheKey firstKey = Iterables.getFirst(keys, null); + HiveTableName hiveTableName = firstKey.getHivePartitionName().getHiveTableName(); + checkArgument(stream(keys).allMatch(key -> key.getHivePartitionName().getHiveTableName().equals(hiveTableName)), "all keys must relate to same hive table"); + Set partitionNames = stream(keys).map(key -> key.getHivePartitionName().getPartitionName()).collect(Collectors.toSet()); + Set columnNames = stream(keys).map(PartitionColumnStatisticsCacheKey::getColumnName).collect(Collectors.toSet()); + + Optional>> columnStatistics = delegate.getPartitionColumnStatistics( + hiveTableName.getDatabaseName(), + hiveTableName.getTableName(), + partitionNames, + columnNames); + + ImmutableMap.Builder> resultMap = ImmutableMap.builder(); + for (PartitionColumnStatisticsCacheKey key : keys) { + if (columnStatistics.isPresent() + && columnStatistics.get().containsKey(key.getHivePartitionName().getPartitionName()) + && columnStatistics.get().get(key.getHivePartitionName().getPartitionName()).containsKey(key.getColumnName())) { + resultMap.put(key, Optional.of(columnStatistics.get().get(key.getHivePartitionName().getPartitionName()).get(key.getColumnName()))); + } + else { + resultMap.put(key, Optional.empty()); + } + } + return resultMap.build(); + } + @Override public Optional> getAllTables(String databaseName) { @@ -882,4 +1021,93 @@ public String toString() .toString(); } } + + private static final class TableColumnStatisticsCacheKey + { + private final HiveTableName hiveTableName; + private final String columnName; + + public TableColumnStatisticsCacheKey(String databaseName, String tableName, String columnName) + { + this.hiveTableName = HiveTableName.table( + requireNonNull(databaseName, "databaseName is null"), + requireNonNull(tableName, "tableName can not be null")); + this.columnName = requireNonNull(columnName, "columnName can not be null"); + } + + public HiveTableName getHiveTableName() + { + return hiveTableName; + } + + public String getColumnName() + { + return columnName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TableColumnStatisticsCacheKey that = (TableColumnStatisticsCacheKey) o; + return Objects.equals(hiveTableName, that.hiveTableName) && + Objects.equals(columnName, that.columnName); + } + + @Override + public int hashCode() + { + return Objects.hash(hiveTableName, columnName); + } + } + + private static final class PartitionColumnStatisticsCacheKey + { + private final HivePartitionName hivePartitionName; + private final String columnName; + + public PartitionColumnStatisticsCacheKey(String databaseName, String tableName, String partitionName, String columnName) + { + this.hivePartitionName = HivePartitionName.partition( + requireNonNull(databaseName, "databaseName is null"), + requireNonNull(tableName, "tableName can not be null"), + requireNonNull(partitionName, "partitionName can not be null")); + this.columnName = requireNonNull(columnName, "columnName can not be null"); + } + + public HivePartitionName getHivePartitionName() + { + return hivePartitionName; + } + + public String getColumnName() + { + return columnName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartitionColumnStatisticsCacheKey that = (PartitionColumnStatisticsCacheKey) o; + return Objects.equals(hivePartitionName, that.hivePartitionName) && + Objects.equals(columnName, that.columnName); + } + + @Override + public int hashCode() + { + return Objects.hash(hivePartitionName, columnName); + } + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java index 455170702e533..c130938b05243 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java @@ -28,6 +28,10 @@ public interface ExtendedHiveMetastore Optional
getTable(String databaseName, String tableName); + Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames); + + Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames); + Optional> getAllTables(String databaseName); Optional> getAllViews(String databaseName); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java new file mode 100644 index 0000000000000..e3eae1fd643ed --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.hive.metastore; + +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalLong; + +public class HiveColumnStatistics +{ + private final Optional lowValue; + private final Optional highValue; + private final OptionalLong maxColumnLength; + private final OptionalDouble averageColumnLength; + private final OptionalLong trueCount; + private final OptionalLong falseCount; + private final OptionalLong nullsCount; + private final OptionalLong distinctValuesCount; + + public HiveColumnStatistics( + Optional lowValue, + Optional highValue, + OptionalLong maxColumnLength, + OptionalDouble averageColumnLength, + OptionalLong trueCount, + OptionalLong falseCount, + OptionalLong nullsCount, + OptionalLong distinctValuesCount) + { + this.lowValue = lowValue; + this.highValue = highValue; + this.maxColumnLength = maxColumnLength; + this.averageColumnLength = averageColumnLength; + this.trueCount = trueCount; + this.falseCount = falseCount; + this.nullsCount = nullsCount; + this.distinctValuesCount = distinctValuesCount; + } + + public Optional getLowValue() + { + return lowValue; + } + + public Optional getHighValue() + { + return highValue; + } + + public OptionalLong getMaxColumnLength() + { + return maxColumnLength; + } + + public OptionalDouble getAverageColumnLength() + { + return averageColumnLength; + } + + public OptionalLong getTrueCount() + { + return trueCount; + } + + public OptionalLong getFalseCount() + { + return falseCount; + } + + public OptionalLong getNullsCount() + { + return nullsCount; + } + + public OptionalLong getDistinctValuesCount() + { + return distinctValuesCount; + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastore.java index 7f849c3cd329e..7e439c6fbeb7f 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastore.java @@ -13,12 +13,14 @@ */ package com.facebook.presto.hive.metastore; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.Partition; import org.apache.hadoop.hive.metastore.api.PrivilegeGrantInfo; import org.apache.hadoop.hive.metastore.api.Table; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -69,6 +71,10 @@ public interface HiveMetastore Optional
getTable(String databaseName, String tableName); + Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames); + + Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames); + Set getRoles(String user); Set getDatabasePrivileges(String user, String databaseName); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastoreClient.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastoreClient.java index 0b808f477db73..312d4616a1582 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastoreClient.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/HiveMetastoreClient.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; import org.apache.hadoop.hive.metastore.api.HiveObjectRef; @@ -26,6 +27,7 @@ import java.io.Closeable; import java.util.List; +import java.util.Map; public interface HiveMetastoreClient extends Closeable @@ -66,6 +68,12 @@ void alterTable(String databaseName, String tableName, Table newTable) Table getTable(String databaseName, String tableName) throws TException; + List getTableColumnStatistics(String databaseName, String tableName, List columnNames) + throws TException; + + Map> getPartitionColumnStatistics(String databaseName, String tableName, List columnNames, List partitionValues) + throws TException; + List getPartitionNames(String databaseName, String tableName) throws TException; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java index 060026802c8a8..7490414009c13 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java @@ -15,29 +15,51 @@ import com.facebook.presto.hive.HiveBucketProperty; import com.facebook.presto.hive.HiveType; +import com.facebook.presto.hive.PartitionOfflineException; +import com.facebook.presto.hive.TableOfflineException; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.apache.hadoop.hive.common.FileUtils; import org.apache.hadoop.hive.metastore.ProtectMode; +import org.apache.hadoop.hive.metastore.api.BinaryColumnStatsData; +import org.apache.hadoop.hive.metastore.api.BooleanColumnStatsData; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; +import org.apache.hadoop.hive.metastore.api.Date; +import org.apache.hadoop.hive.metastore.api.DateColumnStatsData; +import org.apache.hadoop.hive.metastore.api.Decimal; +import org.apache.hadoop.hive.metastore.api.DecimalColumnStatsData; +import org.apache.hadoop.hive.metastore.api.DoubleColumnStatsData; import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.metastore.api.LongColumnStatsData; import org.apache.hadoop.hive.metastore.api.PrincipalPrivilegeSet; import org.apache.hadoop.hive.metastore.api.PrivilegeGrantInfo; import org.apache.hadoop.hive.metastore.api.SerDeInfo; import org.apache.hadoop.hive.metastore.api.StorageDescriptor; +import org.apache.hadoop.hive.metastore.api.StringColumnStatsData; +import javax.annotation.Nullable; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.time.LocalDate; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalLong; import java.util.Properties; import java.util.Set; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_METADATA; +import static com.facebook.presto.hive.HiveSplitManager.PRESTO_OFFLINE; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.parsePrivilege; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.emptyToNull; +import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.base.Strings.nullToEmpty; import static java.lang.String.format; import static java.util.stream.Collectors.toList; @@ -343,6 +365,113 @@ public static Partition fromMetastoreApiPartition(org.apache.hadoop.hive.metasto return partitionBuilder.build(); } + public static HiveColumnStatistics fromMetastoreApiColumnStatistics(ColumnStatisticsObj columnStatistics) + { + if (columnStatistics.getStatsData().isSetLongStats()) { + LongColumnStatsData longStatsData = columnStatistics.getStatsData().getLongStats(); + return new HiveColumnStatistics<>( + longStatsData.isSetLowValue() ? Optional.of(longStatsData.getLowValue()) : Optional.empty(), + longStatsData.isSetHighValue() ? Optional.of(longStatsData.getHighValue()) : Optional.empty(), + OptionalLong.empty(), + OptionalDouble.empty(), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.of(longStatsData.getNumNulls()), + OptionalLong.of(longStatsData.getNumDVs())); + } + else if (columnStatistics.getStatsData().isSetDoubleStats()) { + DoubleColumnStatsData doubleStatsData = columnStatistics.getStatsData().getDoubleStats(); + return new HiveColumnStatistics<>( + doubleStatsData.isSetLowValue() ? Optional.of(doubleStatsData.getLowValue()) : Optional.empty(), + doubleStatsData.isSetHighValue() ? Optional.of(doubleStatsData.getHighValue()) : Optional.empty(), + OptionalLong.empty(), + OptionalDouble.empty(), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.of(doubleStatsData.getNumNulls()), + OptionalLong.of(doubleStatsData.getNumDVs())); + } + else if (columnStatistics.getStatsData().isSetDecimalStats()) { + DecimalColumnStatsData decimalStatsData = columnStatistics.getStatsData().getDecimalStats(); + return new HiveColumnStatistics<>( + decimalStatsData.isSetLowValue() ? fromMetastoreDecimal(decimalStatsData.getLowValue()) : Optional.empty(), + decimalStatsData.isSetHighValue() ? fromMetastoreDecimal(decimalStatsData.getHighValue()) : Optional.empty(), + OptionalLong.empty(), + OptionalDouble.empty(), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.of(decimalStatsData.getNumNulls()), + OptionalLong.of(decimalStatsData.getNumDVs())); + } + else if (columnStatistics.getStatsData().isSetBooleanStats()) { + BooleanColumnStatsData booleanStatsData = columnStatistics.getStatsData().getBooleanStats(); + return new HiveColumnStatistics<>( + Optional.empty(), + Optional.empty(), + OptionalLong.empty(), + OptionalDouble.empty(), + OptionalLong.of(booleanStatsData.getNumTrues()), + OptionalLong.of(booleanStatsData.getNumFalses()), + OptionalLong.of(booleanStatsData.getNumNulls()), + OptionalLong.of((booleanStatsData.getNumFalses() > 0 ? 1 : 0) + (booleanStatsData.getNumTrues() > 0 ? 1 : 0))); + } + else if (columnStatistics.getStatsData().isSetDateStats()) { + DateColumnStatsData dateStatsData = columnStatistics.getStatsData().getDateStats(); + return new HiveColumnStatistics<>( + dateStatsData.isSetLowValue() ? fromMetastoreDate(dateStatsData.getLowValue()) : Optional.empty(), + dateStatsData.isSetHighValue() ? fromMetastoreDate(dateStatsData.getHighValue()) : Optional.empty(), + OptionalLong.empty(), + OptionalDouble.empty(), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.of(dateStatsData.getNumNulls()), + OptionalLong.of(dateStatsData.getNumDVs())); + } + else if (columnStatistics.getStatsData().isSetStringStats()) { + StringColumnStatsData stringStatsData = columnStatistics.getStatsData().getStringStats(); + return new HiveColumnStatistics<>( + Optional.empty(), + Optional.empty(), + OptionalLong.of(stringStatsData.getMaxColLen()), + OptionalDouble.of(stringStatsData.getAvgColLen()), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.of(stringStatsData.getNumNulls()), + OptionalLong.of(stringStatsData.getNumDVs())); + } + else if (columnStatistics.getStatsData().isSetBinaryStats()) { + BinaryColumnStatsData binaryStatsData = columnStatistics.getStatsData().getBinaryStats(); + return new HiveColumnStatistics<>( + Optional.empty(), + Optional.empty(), + OptionalLong.of(binaryStatsData.getMaxColLen()), + OptionalDouble.of(binaryStatsData.getAvgColLen()), + OptionalLong.empty(), + OptionalLong.empty(), + OptionalLong.of(binaryStatsData.getNumNulls()), + OptionalLong.empty()); + } + else { + throw new PrestoException(HIVE_INVALID_METADATA, "Invalid column statistics data: " + columnStatistics); + } + } + + private static Optional fromMetastoreDate(Date date) + { + if (date == null) { + return Optional.empty(); + } + return Optional.of(LocalDate.ofEpochDay(date.getDaysSinceEpoch())); + } + + private static Optional fromMetastoreDecimal(@Nullable Decimal decimal) + { + if (decimal == null) { + return Optional.empty(); + } + return Optional.of(new BigDecimal(new BigInteger(decimal.getUnscaled()), decimal.getScale())); + } + private static PrincipalType fromMetastoreApiPrincipalType(org.apache.hadoop.hive.metastore.api.PrincipalType principalType) { switch (principalType) { @@ -455,4 +584,22 @@ private static void fromMetastoreApiStorageDescriptor(StorageDescriptor storageD .setSkewed(storageDescriptor.isSetSkewedInfo() && storageDescriptor.getSkewedInfo().isSetSkewedColNames() && !storageDescriptor.getSkewedInfo().getSkewedColNames().isEmpty()) .setSerdeParameters(serdeInfo.getParameters() == null ? ImmutableMap.of() : serdeInfo.getParameters()); } + + public static void verifyOnline(SchemaTableName tableName, Optional partitionName, ProtectMode protectMode, Map parameters) + { + if (protectMode.offline) { + if (partitionName.isPresent()) { + throw new PartitionOfflineException(tableName, partitionName.get(), false, null); + } + throw new TableOfflineException(tableName, false, null); + } + + String prestoOffline = parameters.get(PRESTO_OFFLINE); + if (!isNullOrEmpty(prestoOffline)) { + if (partitionName.isPresent()) { + throw new PartitionOfflineException(tableName, partitionName.get(), true, prestoOffline); + } + throw new TableOfflineException(tableName, true, prestoOffline); + } + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java index 02a5a5a3913d7..94ef07774db89 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java @@ -144,6 +144,65 @@ public synchronized Optional
getTable(String databaseName, String tableNa } } + public synchronized Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + checkReadable(); + Action tableAction = tableActions.get(new SchemaTableName(databaseName, tableName)); + if (tableAction == null) { + return delegate.getTableColumnStatistics(databaseName, tableName, columnNames); + } + switch (tableAction.getType()) { + case ADD: + case ALTER: + case INSERT_EXISTING: + case DROP: + return Optional.empty(); + default: + throw new IllegalStateException("Unknown action type"); + } + } + + public synchronized Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames) + { + checkReadable(); + Optional
table = getTable(databaseName, tableName); + if (!table.isPresent()) { + return Optional.empty(); + } + TableSource tableSource = getTableSource(databaseName, tableName); + Map, Action> partitionActionsOfTable = partitionActions.computeIfAbsent(new SchemaTableName(databaseName, tableName), k -> new HashMap<>()); + ImmutableSet.Builder partitionNamesToQuery = ImmutableSet.builder(); + ImmutableMap.Builder> resultBuilder = ImmutableMap.builder(); + for (String partitionName : partitionNames) { + List partitionValues = toPartitionValues(partitionName); + Action partitionAction = partitionActionsOfTable.get(partitionValues); + if (partitionAction == null) { + switch (tableSource) { + case PRE_EXISTING_TABLE: + partitionNamesToQuery.add(partitionName); + break; + case CREATED_IN_THIS_TRANSACTION: + resultBuilder.put(partitionName, ImmutableMap.of()); + break; + default: + throw new UnsupportedOperationException("unknown table source"); + } + } + else { + resultBuilder.put(partitionName, ImmutableMap.of()); + } + } + + Optional>> delegateResult = delegate.getPartitionColumnStatistics(databaseName, tableName, partitionNamesToQuery.build(), columnNames); + if (delegateResult.isPresent()) { + resultBuilder.putAll(delegateResult.get()); + } + else { + partitionNamesToQuery.build().forEach(partionName -> resultBuilder.put(partionName, ImmutableMap.of())); + } + return Optional.of(resultBuilder.build()); + } + /** * This method can only be called when the table is known to exist */ @@ -1562,6 +1621,11 @@ private static void renameDirectory(String user, HdfsEnvironment hdfsEnvironment } } + private static Optional getPrestoQueryId(Table table) + { + return Optional.ofNullable(table.getParameters().get(PRESTO_QUERY_ID_NAME)); + } + private static Optional getPrestoQueryId(Partition partition) { return Optional.ofNullable(partition.getParameters().get(PRESTO_QUERY_ID_NAME)); @@ -1976,7 +2040,9 @@ private static class CreateTableOperation public CreateTableOperation(Table table, PrincipalPrivileges privileges) { - this.table = requireNonNull(table, "table is null"); + requireNonNull(table, "table is null"); + checkArgument(getPrestoQueryId(table).isPresent()); + this.table = table; this.privileges = requireNonNull(privileges, "privileges is null"); } @@ -1987,8 +2053,28 @@ public String getDescription() public void run(ExtendedHiveMetastore metastore) { - metastore.createTable(table, privileges); - done = true; + try { + metastore.createTable(table, privileges); + done = true; + } + catch (RuntimeException e) { + try { + Optional
remoteTable = metastore.getTable(table.getDatabaseName(), table.getTableName()); + // getPrestoQueryId(partition) is guaranteed to be non-empty. It is asserted in the constructor. + if (remoteTable.isPresent() && getPrestoQueryId(remoteTable.get()).equals(getPrestoQueryId(table))) { + done = true; + } + } + catch (RuntimeException ignored) { + // When table could not be fetched from metastore, it is not known whether the table was added. + // Deleting the table when aborting commit has the risk of deleting table not added in this transaction. + // Not deleting the table may leave garbage behind. The former is much more dangerous than the latter. + // Therefore, the table is not considered added. + } + if (!done) { + throw e; + } + } } public void undo(ExtendedHiveMetastore metastore) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastore.java index 55c4a3d06054d..3a70bb9262fb8 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastore.java @@ -29,6 +29,7 @@ import com.google.common.collect.Iterables; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.api.AlreadyExistsException; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; import org.apache.hadoop.hive.metastore.api.HiveObjectRef; @@ -73,6 +74,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toSet; import static org.apache.hadoop.hive.metastore.api.HiveObjectType.DATABASE; import static org.apache.hadoop.hive.metastore.api.HiveObjectType.TABLE; @@ -220,6 +222,59 @@ public Optional
getTable(String databaseName, String tableName) } } + @Override + public Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + try { + return retry() + .stopOn(NoSuchObjectException.class, HiveViewNotSupportedException.class) + .stopOnIllegalExceptions() + .run("getTableColumnStatistics", stats.getGetTableColumnStatistics().wrap(() -> { + try (HiveMetastoreClient client = clientProvider.createMetastoreClient()) { + return Optional.of(ImmutableSet.copyOf(client.getTableColumnStatistics(databaseName, tableName, ImmutableList.copyOf(columnNames)))); + } + })); + } + catch (NoSuchObjectException e) { + return Optional.empty(); + } + catch (TException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } + catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @Override + public Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionValues, Set columnNames) + { + try { + return retry() + .stopOn(NoSuchObjectException.class, HiveViewNotSupportedException.class) + .stopOnIllegalExceptions() + .run("getPartitionColumnStatistics", stats.getGetPartitionColumnStatistics().wrap(() -> { + try (HiveMetastoreClient client = clientProvider.createMetastoreClient()) { + Map> partitionColumnStatistics = client.getPartitionColumnStatistics(databaseName, tableName, ImmutableList.copyOf(columnNames), ImmutableList.copyOf(partitionValues)); + return Optional.of(partitionColumnStatistics.entrySet() + .stream() + .collect(toMap( + Map.Entry::getKey, + entry -> ImmutableSet.copyOf(entry.getValue())))); + } + })); + } + catch (NoSuchObjectException e) { + return Optional.empty(); + } + catch (TException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } + catch (Exception e) { + throw Throwables.propagate(e); + } + } + @Override public Optional> getAllViews(String databaseName) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastoreStats.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastoreStats.java index 3f2f1fad746b2..3a66b0e9f4378 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastoreStats.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/ThriftHiveMetastoreStats.java @@ -23,6 +23,8 @@ public class ThriftHiveMetastoreStats private final HiveMetastoreApiStats getAllTables = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats getAllViews = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats getTable = new HiveMetastoreApiStats(); + private final HiveMetastoreApiStats getTableColumnStatistics = new HiveMetastoreApiStats(); + private final HiveMetastoreApiStats getPartitionColumnStatistics = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats getPartitionNames = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats getPartitionNamesPs = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats getPartition = new HiveMetastoreApiStats(); @@ -77,6 +79,16 @@ public HiveMetastoreApiStats getGetTable() return getTable; } + public HiveMetastoreApiStats getGetTableColumnStatistics() + { + return getTableColumnStatistics; + } + + public HiveMetastoreApiStats getGetPartitionColumnStatistics() + { + return getPartitionColumnStatistics; + } + @Managed @Nested public HiveMetastoreApiStats getGetPartitionNames() diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java index b1b3427e7ebf0..b6b7e3f95dde9 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java @@ -20,6 +20,7 @@ import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.Database; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; +import com.facebook.presto.hive.metastore.HiveColumnStatistics; import com.facebook.presto.hive.metastore.HivePrivilegeInfo; import com.facebook.presto.hive.metastore.Partition; import com.facebook.presto.hive.metastore.PrincipalPrivileges; @@ -249,6 +250,18 @@ public synchronized Optional
getTable(String databaseName, String tableNa .map(tableMetadata -> tableMetadata.toTable(databaseName, tableName, tableMetadataDirectory.toString())); } + @Override + public Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + return Optional.of(ImmutableMap.of()); + } + + @Override + public Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames) + { + return Optional.of(ImmutableMap.of()); + } + private Table getRequiredTable(String databaseName, String tableName) { return getTable(databaseName, tableName) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java index 5d3b801f5b297..9c51f56c76328 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfPageSourceFactory.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.orc; import com.facebook.hive.orc.OrcSerde; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HivePageSourceFactory; @@ -34,6 +35,7 @@ import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxBufferSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxMergeDistance; +import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxReadBlockSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcStreamBufferSize; import static com.facebook.presto.hive.HiveUtil.isDeserializerClass; import static com.facebook.presto.hive.orc.OrcPageSourceFactory.createOrcPageSource; @@ -44,12 +46,14 @@ public class DwrfPageSourceFactory { private final TypeManager typeManager; private final HdfsEnvironment hdfsEnvironment; + private final FileFormatDataSourceStats stats; @Inject - public DwrfPageSourceFactory(TypeManager typeManager, HdfsEnvironment hdfsEnvironment) + public DwrfPageSourceFactory(TypeManager typeManager, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -83,6 +87,8 @@ public Optional createPageSource(Configuration co getOrcMaxMergeDistance(session), getOrcMaxBufferSize(session), getOrcStreamBufferSize(session), - false)); + getOrcMaxReadBlockSize(session), + false, + stats)); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java index e523b03a85059..f852397915462 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java @@ -13,7 +13,9 @@ */ package com.facebook.presto.hive.orc; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.orc.AbstractOrcDataSource; +import com.facebook.presto.orc.OrcDataSourceId; import com.facebook.presto.spi.PrestoException; import io.airlift.units.DataSize; import org.apache.hadoop.fs.FSDataInputStream; @@ -23,16 +25,26 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_MISSING_DATA; import static com.facebook.presto.hive.HiveErrorCode.HIVE_UNKNOWN_ERROR; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public class HdfsOrcDataSource extends AbstractOrcDataSource { private final FSDataInputStream inputStream; + private final FileFormatDataSourceStats stats; - public HdfsOrcDataSource(String name, long size, DataSize maxMergeDistance, DataSize maxReadSize, DataSize streamBufferSize, FSDataInputStream inputStream) + public HdfsOrcDataSource( + OrcDataSourceId id, + long size, + DataSize maxMergeDistance, + DataSize maxReadSize, + DataSize streamBufferSize, + FSDataInputStream inputStream, + FileFormatDataSourceStats stats) { - super(name, size, maxMergeDistance, maxReadSize, streamBufferSize); - this.inputStream = inputStream; + super(id, size, maxMergeDistance, maxReadSize, streamBufferSize); + this.inputStream = requireNonNull(inputStream, "inputStream is null"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -47,7 +59,9 @@ protected void readInternal(long position, byte[] buffer, int bufferOffset, int throws IOException { try { + long readStart = System.nanoTime(); inputStream.readFully(position, buffer, bufferOffset, bufferLength); + stats.readDataBytesPerSecond(bufferLength, System.nanoTime() - readStart); } catch (PrestoException e) { // just in case there is a Presto wrapper or hook diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java index ea74549fc83e3..a09794c54cb45 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.orc; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.orc.OrcCorruptionException; import com.facebook.presto.orc.OrcDataSource; @@ -60,18 +61,23 @@ public class OrcPageSource private final AggregatedMemoryContext systemMemoryContext; + private final FileFormatDataSourceStats stats; + public OrcPageSource( OrcRecordReader recordReader, OrcDataSource orcDataSource, List columns, TypeManager typeManager, - AggregatedMemoryContext systemMemoryContext) + AggregatedMemoryContext systemMemoryContext, + FileFormatDataSourceStats stats) { this.recordReader = requireNonNull(recordReader, "recordReader is null"); this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); int size = requireNonNull(columns, "columns is null").size(); + this.stats = requireNonNull(stats, "stats is null"); + this.constantBlocks = new Block[size]; this.hiveColumnIndexes = new int[size]; @@ -145,7 +151,7 @@ public Page getNextPage() blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize); } else { - blocks[fieldId] = new LazyBlock(batchSize, new OrcBlockLoader(hiveColumnIndexes[fieldId], type)); + blocks[fieldId] = new LazyBlock(batchSize, new OrcBlockLoader(hiveColumnIndexes[fieldId], type, stats)); } } return new Page(batchSize, blocks); @@ -170,6 +176,7 @@ public void close() closed = true; try { + stats.addMaxCombinedBytesPerRow(recordReader.getMaxCombinedBytesPerRow()); recordReader.close(); } catch (IOException e) { @@ -212,12 +219,14 @@ private final class OrcBlockLoader private final int expectedBatchId = batchId; private final int columnIndex; private final Type type; + private final FileFormatDataSourceStats stats; private boolean loaded; - public OrcBlockLoader(int columnIndex, Type type) + public OrcBlockLoader(int columnIndex, Type type, FileFormatDataSourceStats stats) { this.columnIndex = columnIndex; this.type = requireNonNull(type, "type is null"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -240,6 +249,8 @@ public final void load(LazyBlock lazyBlock) throw new PrestoException(HIVE_CURSOR_ERROR, e); } + stats.addLoadedBlockSize(lazyBlock.getSizeInBytes()); + loaded = true; } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java index 55e7a9dc78834..90de4f810c030 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java @@ -13,11 +13,13 @@ */ package com.facebook.presto.hive.orc; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HivePageSourceFactory; import com.facebook.presto.orc.OrcDataSource; +import com.facebook.presto.orc.OrcDataSourceId; import com.facebook.presto.orc.OrcPredicate; import com.facebook.presto.orc.OrcReader; import com.facebook.presto.orc.OrcRecordReader; @@ -58,6 +60,7 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_MISSING_DATA; import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxBufferSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxMergeDistance; +import static com.facebook.presto.hive.HiveSessionProperties.getOrcMaxReadBlockSize; import static com.facebook.presto.hive.HiveSessionProperties.getOrcStreamBufferSize; import static com.facebook.presto.hive.HiveSessionProperties.isOrcBloomFiltersEnabled; import static com.facebook.presto.hive.HiveUtil.isDeserializerClass; @@ -72,18 +75,20 @@ public class OrcPageSourceFactory private final TypeManager typeManager; private final boolean useOrcColumnNames; private final HdfsEnvironment hdfsEnvironment; + private final FileFormatDataSourceStats stats; @Inject - public OrcPageSourceFactory(TypeManager typeManager, HiveClientConfig config, HdfsEnvironment hdfsEnvironment) + public OrcPageSourceFactory(TypeManager typeManager, HiveClientConfig config, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats) { - this(typeManager, requireNonNull(config, "hiveClientConfig is null").isUseOrcColumnNames(), hdfsEnvironment); + this(typeManager, requireNonNull(config, "hiveClientConfig is null").isUseOrcColumnNames(), hdfsEnvironment, stats); } - public OrcPageSourceFactory(TypeManager typeManager, boolean useOrcColumnNames, HdfsEnvironment hdfsEnvironment) + public OrcPageSourceFactory(TypeManager typeManager, boolean useOrcColumnNames, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.useOrcColumnNames = useOrcColumnNames; this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -118,7 +123,9 @@ public Optional createPageSource( getOrcMaxMergeDistance(session), getOrcMaxBufferSize(session), getOrcStreamBufferSize(session), - isOrcBloomFiltersEnabled(session))); + getOrcMaxReadBlockSize(session), + isOrcBloomFiltersEnabled(session), + stats)); } public static OrcPageSource createOrcPageSource( @@ -137,14 +144,16 @@ public static OrcPageSource createOrcPageSource( DataSize maxMergeDistance, DataSize maxBufferSize, DataSize streamBufferSize, - boolean orcBloomFiltersEnabled) + DataSize maxReadBlockSize, + boolean orcBloomFiltersEnabled, + FileFormatDataSourceStats stats) { OrcDataSource orcDataSource; try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(sessionUser, path, configuration); long size = fileSystem.getFileStatus(path).getLen(); FSDataInputStream inputStream = fileSystem.open(path); - orcDataSource = new HdfsOrcDataSource(path.toString(), size, maxMergeDistance, maxBufferSize, streamBufferSize, inputStream); + orcDataSource = new HdfsOrcDataSource(new OrcDataSourceId(path.toString()), size, maxMergeDistance, maxBufferSize, streamBufferSize, inputStream, stats); } catch (Exception e) { if (nullToEmpty(e.getMessage()).trim().equals("Filesystem closed") || @@ -156,7 +165,7 @@ public static OrcPageSource createOrcPageSource( AggregatedMemoryContext systemMemoryUsage = new AggregatedMemoryContext(); try { - OrcReader reader = new OrcReader(orcDataSource, metadataReader, maxMergeDistance, maxBufferSize); + OrcReader reader = new OrcReader(orcDataSource, metadataReader, maxMergeDistance, maxBufferSize, maxReadBlockSize); List physicalColumns = getPhysicalHiveColumnHandles(columns, useOrcColumnNames, reader, path); ImmutableMap.Builder includedColumns = ImmutableMap.builder(); @@ -184,7 +193,8 @@ public static OrcPageSource createOrcPageSource( orcDataSource, physicalColumns, typeManager, - systemMemoryUsage); + systemMemoryUsage, + stats); } catch (Exception e) { try { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java index d0199f546cf9e..7ac5883045e6d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetReader.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.InterleavedBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.NamedTypeSignature; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -203,14 +204,13 @@ private Block readMap(Type type, List path, IntList elementOffsets) } return RunLengthEncodedBlock.create(parameters.get(0), null, batchSize); } - InterleavedBlock interleavedBlock = new InterleavedBlock(new Block[] {blocks[0], blocks[1]}); int[] offsets = new int[batchSize + 1]; for (int i = 1; i < offsets.length; i++) { - int elementPositionCount = keyOffsets.getInt(i - 1) * 2; - elementOffsets.add(elementPositionCount); + int elementPositionCount = keyOffsets.getInt(i - 1); + elementOffsets.add(elementPositionCount * 2); offsets[i] = offsets[i - 1] + elementPositionCount; } - return new ArrayBlock(batchSize, new boolean[batchSize], offsets, interleavedBlock); + return ((MapType) type).createBlockFromKeyValue(new boolean[batchSize], offsets, blocks[0], blocks[1]); } public Block readStruct(Type type, List path) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java index b96dd1efe59e7..b47d42115baac 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/reader/ParquetShortDecimalColumnReader.java @@ -59,7 +59,15 @@ else if (columnDescriptor.getType().equals(INT64)) { protected void skipValue() { if (definitionLevel == columnDescriptor.getMaxDefinitionLevel()) { - valuesReader.readBytes(); + if (columnDescriptor.getType().equals(INT32)) { + valuesReader.readInteger(); + } + else if (columnDescriptor.getType().equals(INT64)) { + valuesReader.readLong(); + } + else { + valuesReader.readBytes(); + } } } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/HdfsRcFileDataSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/HdfsRcFileDataSource.java index 2bd6ea1086ff0..1e29dfff2fe5b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/HdfsRcFileDataSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/HdfsRcFileDataSource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.rcfile; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.rcfile.RcFileDataSource; import org.apache.hadoop.fs.FSDataInputStream; @@ -27,15 +28,17 @@ public class HdfsRcFileDataSource private final FSDataInputStream inputStream; private final String path; private final long size; + private final FileFormatDataSourceStats stats; private long readTimeNanos; private long readBytes; - public HdfsRcFileDataSource(String path, FSDataInputStream inputStream, long size) + public HdfsRcFileDataSource(String path, FSDataInputStream inputStream, long size, FileFormatDataSourceStats stats) { this.path = requireNonNull(path, "path is null"); this.inputStream = requireNonNull(inputStream, "inputStream is null"); this.size = size; checkArgument(size >= 0, "size is negative"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -71,7 +74,10 @@ public void readFully(long position, byte[] buffer, int bufferOffset, int buffer inputStream.readFully(position, buffer, bufferOffset, bufferLength); - readTimeNanos += System.nanoTime() - start; + long readDuration = System.nanoTime() - start; + stats.readDataBytesPerSecond(bufferLength, readDuration); + + readTimeNanos += readDuration; readBytes += bufferLength; } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java index 8a525bd6268fb..ec845329b0f7c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.rcfile; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HivePageSourceFactory; @@ -52,7 +53,6 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; import static com.facebook.presto.hive.HiveErrorCode.HIVE_MISSING_DATA; -import static com.facebook.presto.hive.HiveSessionProperties.isRcfileOptimizedReaderEnabled; import static com.facebook.presto.hive.HiveUtil.getDeserializerClassName; import static com.facebook.presto.rcfile.text.TextRcFileEncoding.DEFAULT_NULL_SEQUENCE; import static com.facebook.presto.rcfile.text.TextRcFileEncoding.DEFAULT_SEPARATORS; @@ -77,12 +77,14 @@ public class RcFilePageSourceFactory private final TypeManager typeManager; private final HdfsEnvironment hdfsEnvironment; + private final FileFormatDataSourceStats stats; @Inject - public RcFilePageSourceFactory(TypeManager typeManager, HdfsEnvironment hdfsEnvironment) + public RcFilePageSourceFactory(TypeManager typeManager, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.stats = requireNonNull(stats, "stats is null"); } @Override @@ -97,10 +99,6 @@ public Optional createPageSource( TupleDomain effectivePredicate, DateTimeZone hiveStorageTimeZone) { - if (!isRcfileOptimizedReaderEnabled(session)) { - return Optional.empty(); - } - RcFileEncoding rcFileEncoding; String deserializerClassName = getDeserializerClassName(schema); if (deserializerClassName.equals(LazyBinaryColumnarSerDe.class.getName())) { @@ -135,7 +133,7 @@ else if (deserializerClassName.equals(ColumnarSerDe.class.getName())) { } RcFileReader rcFileReader = new RcFileReader( - new HdfsRcFileDataSource(path.toString(), inputStream, size), + new HdfsRcFileDataSource(path.toString(), inputStream, size, stats), rcFileEncoding, readColumns.build(), new AircompressorCodecFactory(new HadoopCodecFactory(configuration.getClassLoader())), diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java index 1f423a168c3ac..731b96e01e633 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java @@ -187,12 +187,12 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String property } @Override - public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { } @Override - public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java index 84c35b97f9391..8b0fa940109c1 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java @@ -247,7 +247,7 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String property } @Override - public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { if (checkTablePermission(transaction, identity, tableName, OWNERSHIP)) { return; @@ -260,7 +260,7 @@ public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, } @Override - public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { if (checkTablePermission(transaction, identity, tableName, OWNERSHIP)) { return; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/HiveStatisticsProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/HiveStatisticsProvider.java new file mode 100644 index 0000000000000..fe61ef09d507f --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/HiveStatisticsProvider.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.hive.statistics; + +import com.facebook.presto.hive.HivePartition; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.statistics.TableStatistics; + +import java.util.List; +import java.util.Map; + +public interface HiveStatisticsProvider +{ + TableStatistics getTableStatistics( + ConnectorSession session, + ConnectorTableHandle tableHandle, + List hivePartitions, + Map tableColumns); +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java new file mode 100644 index 0000000000000..6b69fef2c656f --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/MetastoreHiveStatisticsProvider.java @@ -0,0 +1,302 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.hive.statistics; + +import com.facebook.presto.hive.HiveColumnHandle; +import com.facebook.presto.hive.HivePartition; +import com.facebook.presto.hive.HiveTableHandle; +import com.facebook.presto.hive.PartitionStatistics; +import com.facebook.presto.hive.metastore.HiveColumnStatistics; +import com.facebook.presto.hive.metastore.Partition; +import com.facebook.presto.hive.metastore.SemiTransactionalHiveMetastore; +import com.facebook.presto.hive.metastore.Table; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.statistics.ColumnStatistics; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.TypeManager; +import com.google.common.collect.ImmutableMap; + +import javax.annotation.Nullable; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalLong; +import java.util.PrimitiveIterator; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; + +import static com.facebook.presto.hive.HiveSessionProperties.isStatisticsEnabled; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; + +public class MetastoreHiveStatisticsProvider + implements HiveStatisticsProvider +{ + private final TypeManager typeManager; + private final SemiTransactionalHiveMetastore metastore; + + public MetastoreHiveStatisticsProvider(TypeManager typeManager, SemiTransactionalHiveMetastore metastore) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.metastore = requireNonNull(metastore, "metastore is null"); + } + + @Override + public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, List hivePartitions, Map tableColumns) + { + if (!isStatisticsEnabled(session)) { + return TableStatistics.EMPTY_STATISTICS; + } + Map partitionStatistics = getPartitionsStatistics((HiveTableHandle) tableHandle, hivePartitions, tableColumns.keySet()); + + TableStatistics.Builder tableStatistics = TableStatistics.builder(); + tableStatistics.setRowCount(calculateRowsCount(partitionStatistics)); + for (Map.Entry columnEntry : tableColumns.entrySet()) { + String columnName = columnEntry.getKey(); + HiveColumnHandle hiveColumnHandle = (HiveColumnHandle) columnEntry.getValue(); + if (getColumnMetadata(hiveColumnHandle).isHidden()) { + continue; + } + ColumnStatistics.Builder columnStatistics = ColumnStatistics.builder(); + if (hiveColumnHandle.isPartitionKey()) { + columnStatistics.setDistinctValuesCount(countDistinctPartitionKeys(hiveColumnHandle, hivePartitions)); + columnStatistics.setNullsCount(calculateNullsCountForPartitioningKey(hiveColumnHandle, hivePartitions, partitionStatistics)); + } + else { + columnStatistics.setDistinctValuesCount(calculateDistinctValuesCount(partitionStatistics, columnName)); + columnStatistics.setNullsCount(calculateNullsCount(partitionStatistics, columnName)); + } + tableStatistics.setColumnStatistics(hiveColumnHandle, columnStatistics.build()); + } + return tableStatistics.build(); + } + + private Estimate calculateRowsCount(Map partitionStatistics) + { + List knownPartitionRowCounts = partitionStatistics.values().stream() + .map(PartitionStatistics::getRowCount) + .filter(OptionalLong::isPresent) + .map(OptionalLong::getAsLong) + .collect(toList()); + + long knownPartitionRowCountsSum = knownPartitionRowCounts.stream().mapToLong(a -> a).sum(); + long partitionsWithStatsCount = knownPartitionRowCounts.size(); + long allPartitionsCount = partitionStatistics.size(); + + if (partitionsWithStatsCount == 0) { + return Estimate.unknownValue(); + } + return new Estimate(1.0 * knownPartitionRowCountsSum / partitionsWithStatsCount * allPartitionsCount); + } + + private Estimate calculateDistinctValuesCount(Map statisticsByPartitionName, String column) + { + return summarizePartitionStatistics( + statisticsByPartitionName.values(), + column, + columnStatistics -> { + if (columnStatistics.getDistinctValuesCount().isPresent()) { + return OptionalDouble.of(columnStatistics.getDistinctValuesCount().getAsLong()); + } + else { + return OptionalDouble.empty(); + } + }, + DoubleStream::max); + } + + private Estimate calculateNullsCount(Map statisticsByPartitionName, String column) + { + return summarizePartitionStatistics( + statisticsByPartitionName.values(), + column, + columnStatistics -> { + if (columnStatistics.getNullsCount().isPresent()) { + return OptionalDouble.of(columnStatistics.getNullsCount().getAsLong()); + } + else { + return OptionalDouble.empty(); + } + }, + nullsCountStream -> { + double totalNullsCount = 0; + long partitionsWithStatisticsCount = 0; + for (PrimitiveIterator.OfDouble nullsCountIterator = nullsCountStream.iterator(); nullsCountIterator.hasNext(); ) { + double nullsCount = nullsCountIterator.nextDouble(); + totalNullsCount += nullsCount; + partitionsWithStatisticsCount++; + } + + if (partitionsWithStatisticsCount == 0) { + return OptionalDouble.empty(); + } + else { + int allPartitionsCount = statisticsByPartitionName.size(); + return OptionalDouble.of(allPartitionsCount / partitionsWithStatisticsCount * totalNullsCount); + } + }); + } + + private Estimate countDistinctPartitionKeys(HiveColumnHandle partitionColumn, List partitions) + { + return new Estimate(partitions.stream() + .map(HivePartition::getKeys) + .map(keys -> keys.get(partitionColumn)) + .distinct() + .count()); + } + + private Estimate calculateNullsCountForPartitioningKey(HiveColumnHandle partitionColumn, List partitions, Map partitionStatistics) + { + OptionalDouble rowsPerPartition = partitionStatistics.values().stream() + .map(PartitionStatistics::getRowCount) + .filter(OptionalLong::isPresent) + .mapToLong(OptionalLong::getAsLong) + .average(); + + if (!rowsPerPartition.isPresent()) { + return Estimate.unknownValue(); + } + + return new Estimate(partitions.stream() + .filter(partition -> partition.getKeys().get(partitionColumn).isNull()) + .map(HivePartition::getPartitionId) + .mapToLong(partitionId -> partitionStatistics.get(partitionId).getRowCount().orElse((long) rowsPerPartition.getAsDouble())) + .sum()); + } + + private Estimate summarizePartitionStatistics( + Collection partitionStatistics, + String column, + Function valueExtractFunction, + Function valueAggregateFunction) + { + DoubleStream intermediateStream = partitionStatistics.stream() + .map(PartitionStatistics::getColumnStatistics) + .filter(stats -> stats.containsKey(column)) + .map(stats -> stats.get(column)) + .map(valueExtractFunction) + .filter(OptionalDouble::isPresent) + .mapToDouble(OptionalDouble::getAsDouble); + + OptionalDouble statisticsValue = valueAggregateFunction.apply(intermediateStream); + + if (statisticsValue.isPresent()) { + return new Estimate(statisticsValue.getAsDouble()); + } + else { + return Estimate.unknownValue(); + } + } + + private Map getPartitionsStatistics(HiveTableHandle tableHandle, List hivePartitions, Set tableColumns) + { + if (hivePartitions.isEmpty()) { + return ImmutableMap.of(); + } + boolean unpartitioned = hivePartitions.stream().anyMatch(partition -> partition.getPartitionId().equals(HivePartition.UNPARTITIONED_ID)); + if (unpartitioned) { + checkArgument(hivePartitions.size() == 1, "expected only one hive partition"); + } + + if (unpartitioned) { + return ImmutableMap.of(HivePartition.UNPARTITIONED_ID, getTableStatistics(tableHandle.getSchemaTableName(), tableColumns)); + } + else { + return getPartitionsStatistics(tableHandle.getSchemaTableName(), hivePartitions, tableColumns); + } + } + + private Map getPartitionsStatistics(SchemaTableName schemaTableName, List hivePartitions, Set tableColumns) + { + String databaseName = schemaTableName.getSchemaName(); + String tableName = schemaTableName.getTableName(); + + ImmutableMap.Builder resultMap = ImmutableMap.builder(); + + List partitionNames = hivePartitions.stream().map(HivePartition::getPartitionId).collect(Collectors.toList()); + Map> partitionColumnStatisticsMap = + metastore.getPartitionColumnStatistics(databaseName, tableName, new HashSet<>(partitionNames), tableColumns) + .orElse(ImmutableMap.of()); + + Map> partitionsByNames = metastore.getPartitionsByNames(databaseName, tableName, partitionNames); + for (String partitionName : partitionNames) { + Map partitionParameters = partitionsByNames.get(partitionName) + .map(Partition::getParameters) + .orElseThrow(() -> new IllegalArgumentException(format("Could not get metadata for partition %s.%s.%s", databaseName, tableName, partitionName))); + Map partitionColumnStatistics = partitionColumnStatisticsMap.getOrDefault(partitionName, ImmutableMap.of()); + resultMap.put(partitionName, readStatisticsFromParameters(partitionParameters, partitionColumnStatistics)); + } + + return resultMap.build(); + } + + private PartitionStatistics getTableStatistics(SchemaTableName schemaTableName, Set tableColumns) + { + String databaseName = schemaTableName.getSchemaName(); + String tableName = schemaTableName.getTableName(); + Table table = metastore.getTable(databaseName, tableName) + .orElseThrow(() -> new IllegalArgumentException(format("Could not get metadata for table %s.%s", databaseName, tableName))); + + Map tableColumnStatistics = metastore.getTableColumnStatistics(databaseName, tableName, tableColumns).orElse(ImmutableMap.of()); + + return readStatisticsFromParameters(table.getParameters(), tableColumnStatistics); + } + + private PartitionStatistics readStatisticsFromParameters(Map parameters, Map columnStatistics) + { + boolean columnStatsAcurate = Boolean.valueOf(Optional.ofNullable(parameters.get("COLUMN_STATS_ACCURATE")).orElse("false")); + OptionalLong numFiles = convertStringParameter(parameters.get("numFiles")); + OptionalLong numRows = convertStringParameter(parameters.get("numRows")); + OptionalLong rawDataSize = convertStringParameter(parameters.get("rawDataSize")); + OptionalLong totalSize = convertStringParameter(parameters.get("totalSize")); + return new PartitionStatistics(columnStatsAcurate, numFiles, numRows, rawDataSize, totalSize, columnStatistics); + } + + private OptionalLong convertStringParameter(@Nullable String parameterValue) + { + if (parameterValue == null) { + return OptionalLong.empty(); + } + try { + long longValue = Long.parseLong(parameterValue); + if (longValue < 0) { + return OptionalLong.empty(); + } + return OptionalLong.of(longValue); + } + catch (NumberFormatException e) { + return OptionalLong.empty(); + } + } + + private ColumnMetadata getColumnMetadata(ColumnHandle columnHandle) + { + return ((HiveColumnHandle) columnHandle).getColumnMetadata(typeManager); + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java b/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java index b9bdd521bdae4..a8ffe3df70191 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/thrift/Transport.java @@ -82,7 +82,7 @@ private static void closeQuietly(Closeable closeable) private static Socket createSocksSocket(HostAndPort proxy) { - SocketAddress address = InetSocketAddress.createUnresolved(proxy.getHostText(), proxy.getPort()); + SocketAddress address = InetSocketAddress.createUnresolved(proxy.getHost(), proxy.getPort()); return new Socket(new Proxy(Proxy.Type.SOCKS, address)); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java index 27d9d0854c3e9..b647ab89791cd 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.util; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapred.JobConf; import java.util.Map; @@ -27,4 +28,12 @@ public static void copy(Configuration from, Configuration to) to.set(entry.getKey(), entry.getValue()); } } + + public static JobConf toJobConf(Configuration conf) + { + if (conf instanceof JobConf) { + return (JobConf) conf; + } + return new JobConf(conf); + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/SerDeUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/SerDeUtils.java index f5286fb9c9fcb..26b8a5de987db 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/SerDeUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/SerDeUtils.java @@ -216,12 +216,13 @@ private static Block serializeMap(Type type, BlockBuilder builder, Object object ObjectInspector keyInspector = inspector.getMapKeyObjectInspector(); ObjectInspector valueInspector = inspector.getMapValueObjectInspector(); BlockBuilder currentBuilder; - if (builder != null) { - currentBuilder = builder.beginBlockEntry(); - } - else { - currentBuilder = new InterleavedBlockBuilder(typeParameters, new BlockBuilderStatus(), map.size()); + + boolean builderSynthesized = false; + if (builder == null) { + builderSynthesized = true; + builder = type.createBlockBuilder(new BlockBuilderStatus(), 1); } + currentBuilder = builder.beginBlockEntry(); for (Map.Entry entry : map.entrySet()) { // Hive skips map entries with null keys @@ -231,13 +232,12 @@ private static Block serializeMap(Type type, BlockBuilder builder, Object object } } - if (builder != null) { - builder.closeEntry(); - return null; + builder.closeEntry(); + if (builderSynthesized) { + return (Block) type.getObject(builder, 0); } else { - Block resultBlock = currentBuilder.build(); - return resultBlock; + return null; } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java index bc35994d93097..d1b1ec6ec4dd1 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java @@ -66,6 +66,8 @@ import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.predicate.ValueSet; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.NamedTypeSignature; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlTimestamp; @@ -78,8 +80,6 @@ import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.testing.TestingNodeManager; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; @@ -124,8 +124,9 @@ import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_PARTITION_VALUE; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_SCHEMA_MISMATCH; +import static com.facebook.presto.hive.HiveMetadata.PRESTO_QUERY_ID_NAME; +import static com.facebook.presto.hive.HiveMetadata.PRESTO_VERSION_NAME; import static com.facebook.presto.hive.HiveMetadata.convertToPredicate; import static com.facebook.presto.hive.HiveStorageFormat.AVRO; import static com.facebook.presto.hive.HiveStorageFormat.DWRF; @@ -1024,9 +1025,9 @@ public void testGetTableSchemaOffline() { try (Transaction transaction = newTransaction()) { ConnectorMetadata metadata = transaction.getMetadata(); - ConnectorTableHandle tableHandle = getTableHandle(metadata, tableOffline); - ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(newSession(), tableHandle); - Map map = uniqueIndex(tableMetadata.getColumns(), ColumnMetadata::getName); + Map> columns = metadata.listTableColumns(newSession(), tableOffline.toSchemaTablePrefix()); + assertEquals(columns.size(), 1); + Map map = uniqueIndex(getOnlyElement(columns.values()), ColumnMetadata::getName); assertPrimitiveField(map, "t_string", createUnboundedVarcharType(), false); } @@ -1114,9 +1115,8 @@ public void testGetPartitionTableOffline() { try (Transaction transaction = newTransaction()) { ConnectorMetadata metadata = transaction.getMetadata(); - ConnectorTableHandle tableHandle = getTableHandle(metadata, tableOffline); try { - metadata.getTableLayouts(newSession(), tableHandle, new Constraint<>(TupleDomain.all(), bindings -> true), Optional.empty()); + getTableHandle(metadata, tableOffline); fail("expected TableOfflineException"); } catch (TableOfflineException e) { @@ -1523,35 +1523,6 @@ public void testTypesRcText() assertGetRecords("presto_test_types_rctext", RCTEXT); } - @Test - public void testTypesRcTextRecordCursor() - throws Exception - { - try (Transaction transaction = newTransaction()) { - ConnectorSession session = newSession(); - ConnectorMetadata metadata = transaction.getMetadata(); - - if (metadata.getTableHandle(session, new SchemaTableName(database, "presto_test_types_rctext")) == null) { - return; - } - - ConnectorTableHandle tableHandle = getTableHandle(metadata, new SchemaTableName(database, "presto_test_types_rctext")); - ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); - HiveSplit hiveSplit = getHiveSplit(tableHandle); - List columnHandles = ImmutableList.copyOf(metadata.getColumnHandles(session, tableHandle).values()); - - ConnectorPageSourceProvider pageSourceProvider = new HivePageSourceProvider( - new HiveClientConfig().setTimeZone(timeZone.getID()), - hdfsEnvironment, - ImmutableSet.of(new ColumnarTextHiveRecordCursorProvider(hdfsEnvironment)), - ImmutableSet.of(), - TYPE_MANAGER); - - ConnectorPageSource pageSource = pageSourceProvider.createPageSource(transaction.getTransactionHandle(), session, hiveSplit, columnHandles); - assertGetRecords(RCTEXT, tableMetadata, hiveSplit, pageSource, columnHandles); - } - } - @Test public void testTypesRcBinary() throws Exception @@ -1559,35 +1530,6 @@ public void testTypesRcBinary() assertGetRecords("presto_test_types_rcbinary", RCBINARY); } - @Test - public void testTypesRcBinaryRecordCursor() - throws Exception - { - try (Transaction transaction = newTransaction()) { - ConnectorSession session = newSession(); - ConnectorMetadata metadata = transaction.getMetadata(); - - if (metadata.getTableHandle(session, new SchemaTableName(database, "presto_test_types_rcbinary")) == null) { - return; - } - - ConnectorTableHandle tableHandle = getTableHandle(metadata, new SchemaTableName(database, "presto_test_types_rcbinary")); - ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); - HiveSplit hiveSplit = getHiveSplit(tableHandle); - List columnHandles = ImmutableList.copyOf(metadata.getColumnHandles(session, tableHandle).values()); - - ConnectorPageSourceProvider pageSourceProvider = new HivePageSourceProvider( - new HiveClientConfig().setTimeZone(timeZone.getID()), - hdfsEnvironment, - ImmutableSet.of(new ColumnarBinaryHiveRecordCursorProvider(hdfsEnvironment)), - ImmutableSet.of(), - TYPE_MANAGER); - - ConnectorPageSource pageSource = pageSourceProvider.createPageSource(transaction.getTransactionHandle(), session, hiveSplit, columnHandles); - assertGetRecords(RCBINARY, tableMetadata, hiveSplit, pageSource, columnHandles); - } - } - @Test public void testTypesOrc() throws Exception @@ -1973,8 +1915,8 @@ protected void doCreateTable(SchemaTableName tableName, HiveStorageFormat storag // verify the node version and query ID in table Table table = getMetastoreClient(tableName.getSchemaName()).getTable(tableName.getSchemaName(), tableName.getTableName()).get(); - assertEquals(table.getParameters().get(HiveMetadata.PRESTO_VERSION_NAME), TEST_SERVER_VERSION); - assertEquals(table.getParameters().get(HiveMetadata.PRESTO_QUERY_ID_NAME), queryId); + assertEquals(table.getParameters().get(PRESTO_VERSION_NAME), TEST_SERVER_VERSION); + assertEquals(table.getParameters().get(PRESTO_QUERY_ID_NAME), queryId); } } @@ -2022,8 +1964,8 @@ protected void doCreateEmptyTable(SchemaTableName tableName, HiveStorageFormat s assertEquals(table.getStorage().getStorageFormat().getInputFormat(), storageFormat.getInputFormat()); // verify the node version and query ID - assertEquals(table.getParameters().get(HiveMetadata.PRESTO_VERSION_NAME), TEST_SERVER_VERSION); - assertEquals(table.getParameters().get(HiveMetadata.PRESTO_QUERY_ID_NAME), queryId); + assertEquals(table.getParameters().get(PRESTO_VERSION_NAME), TEST_SERVER_VERSION); + assertEquals(table.getParameters().get(PRESTO_QUERY_ID_NAME), queryId); // verify the table is empty List columnHandles = filterNonHiddenColumnHandles(metadata.getColumnHandles(session, tableHandle).values()); @@ -2211,7 +2153,7 @@ private void doInsertIntoNewPartition(HiveStorageFormat storageFormat, SchemaTab try (Transaction transaction = newTransaction()) { // verify partitions were created List partitionNames = transaction.getMetastore(tableName.getSchemaName()).getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder(partitionNames, CREATE_TABLE_PARTITIONED_DATA.getMaterializedRows().stream() .map(row -> "ds=" + row.getField(CREATE_TABLE_PARTITIONED_DATA.getTypes().size() - 1)) .collect(toList())); @@ -2221,8 +2163,8 @@ private void doInsertIntoNewPartition(HiveStorageFormat storageFormat, SchemaTab assertEquals(partitions.size(), partitionNames.size()); for (String partitionName : partitionNames) { Partition partition = partitions.get(partitionName).get(); - assertEquals(partition.getParameters().get(HiveMetadata.PRESTO_VERSION_NAME), TEST_SERVER_VERSION); - assertEquals(partition.getParameters().get(HiveMetadata.PRESTO_QUERY_ID_NAME), queryId); + assertEquals(partition.getParameters().get(PRESTO_VERSION_NAME), TEST_SERVER_VERSION); + assertEquals(partition.getParameters().get(PRESTO_QUERY_ID_NAME), queryId); } // load the new table @@ -2328,7 +2270,7 @@ private void doInsertIntoExistingPartition(HiveStorageFormat storageFormat, Sche // verify partitions were created List partitionNames = transaction.getMetastore(tableName.getSchemaName()).getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder(partitionNames, CREATE_TABLE_PARTITIONED_DATA.getMaterializedRows().stream() .map(row -> "ds=" + row.getField(CREATE_TABLE_PARTITIONED_DATA.getTypes().size() - 1)) .collect(toList())); @@ -2451,7 +2393,7 @@ private void doTestMetadataDelete(HiveStorageFormat storageFormat, SchemaTableNa // verify partitions were created List partitionNames = transaction.getMetastore(tableName.getSchemaName()).getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder(partitionNames, CREATE_TABLE_PARTITIONED_DATA.getMaterializedRows().stream() .map(row -> "ds=" + row.getField(CREATE_TABLE_PARTITIONED_DATA.getTypes().size() - 1)) .collect(toList())); @@ -2916,10 +2858,6 @@ protected static void assertPageSourceType(ConnectorPageSource pageSource, HiveS private static Class recordCursorType(HiveStorageFormat hiveStorageFormat) { switch (hiveStorageFormat) { - case RCTEXT: - return ColumnarTextHiveRecordCursor.class; - case RCBINARY: - return ColumnarBinaryHiveRecordCursor.class; case PARQUET: return ParquetHiveRecordCursor.class; } @@ -3089,7 +3027,9 @@ protected void createEmptyTable(SchemaTableName schemaTableName, HiveStorageForm .setTableName(tableName) .setOwner(tableOwner) .setTableType(TableType.MANAGED_TABLE.name()) - .setParameters(ImmutableMap.of()) + .setParameters(ImmutableMap.of( + PRESTO_VERSION_NAME, TEST_SERVER_VERSION, + PRESTO_QUERY_ID_NAME, session.getQueryId())) .setDataColumns(columns) .setPartitionColumns(partitionColumns); @@ -3311,7 +3251,7 @@ private void doTestTransactionDeleteInsert( // verify partitions List partitionNames = transaction.getMetastore(tableName.getSchemaName()) .getPartitionNames(tableName.getSchemaName(), tableName.getTableName()) - .orElseThrow(() -> new PrestoException(HIVE_METASTORE_ERROR, "Partition metadata not available")); + .orElseThrow(() -> new AssertionError("Table does not exist: " + tableName)); assertEqualsIgnoreOrder( partitionNames, expectedData.getMaterializedRows().stream() diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClientLocal.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClientLocal.java index fd49fef535213..6ee20e1c28f57 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClientLocal.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClientLocal.java @@ -80,4 +80,7 @@ public void testGetAllTableColumnsInSchema() {} @Override public void testGetTableNames() {} + + @Override + public void testGetTableSchemaOffline() {} } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java index 92ba173db541b..ea56b623f34d0 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java @@ -23,10 +23,12 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DateType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; @@ -36,9 +38,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.tests.StructuralTestUtil; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -91,6 +90,7 @@ import static com.facebook.presto.hive.HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION; import static com.facebook.presto.hive.HiveTestUtils.SESSION; import static com.facebook.presto.hive.HiveTestUtils.TYPE_MANAGER; +import static com.facebook.presto.hive.HiveTestUtils.mapType; import static com.facebook.presto.hive.HiveUtil.isStructuralType; import static com.facebook.presto.hive.util.SerDeUtils.serializeObject; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -421,14 +421,14 @@ public abstract class AbstractTestHiveFileFormats getStandardMapObjectInspector(javaLongObjectInspector, javaBooleanObjectInspector) ), asMap(new String[] {null, "k"}, new ImmutableMap[] {ImmutableMap.of(15L, true), ImmutableMap.of(16L, false)}), - mapBlockOf(createUnboundedVarcharType(), new MapType(BIGINT, BOOLEAN), "k", mapBlockOf(BIGINT, BOOLEAN, 16L, false)))) + mapBlockOf(createUnboundedVarcharType(), mapType(BIGINT, BOOLEAN), "k", mapBlockOf(BIGINT, BOOLEAN, 16L, false)))) .add(new TestColumn("t_map_null_key_complex_key_value", getStandardMapObjectInspector( getStandardListObjectInspector(javaStringObjectInspector), getStandardMapObjectInspector(javaLongObjectInspector, javaBooleanObjectInspector) ), asMap(new ImmutableList[] {null, ImmutableList.of("k", "ka")}, new ImmutableMap[] {ImmutableMap.of(15L, true), ImmutableMap.of(16L, false)}), - mapBlockOf(new ArrayType(createUnboundedVarcharType()), new MapType(BIGINT, BOOLEAN), arrayBlockOf(createUnboundedVarcharType(), "k", "ka"), mapBlockOf(BIGINT, BOOLEAN, 16L, false)))) + mapBlockOf(new ArrayType(createUnboundedVarcharType()), mapType(BIGINT, BOOLEAN), arrayBlockOf(createUnboundedVarcharType(), "k", "ka"), mapBlockOf(BIGINT, BOOLEAN, 16L, false)))) .add(new TestColumn("t_struct_nested", getStandardStructObjectInspector(ImmutableList.of("struct_field"), ImmutableList.of(getStandardListObjectInspector(javaStringObjectInspector))), ImmutableList.of(ImmutableList.of("1", "2", "3")), rowBlockOf(ImmutableList.of(new ArrayType(createUnboundedVarcharType())), arrayBlockOf(createUnboundedVarcharType(), "1", "2", "3")))) .add(new TestColumn("t_struct_null", getStandardStructObjectInspector(ImmutableList.of("struct_field_null", "struct_field_null2"), diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/HiveQueryRunner.java b/presto-hive/src/test/java/com/facebook/presto/hive/HiveQueryRunner.java index 9319366125f43..11f1d22bfd947 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/HiveQueryRunner.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/HiveQueryRunner.java @@ -102,6 +102,7 @@ public static DistributedQueryRunner createQueryRunner(Iterable> ta .put("hive.metastore.uri", "thrift://localhost:8080") .put("hive.time-zone", TIME_ZONE.getID()) .put("hive.security", security) + .put("hive.max-partitions-per-scan", "1000") .build(); Map hiveBucketedProperties = ImmutableMap.builder() .putAll(hiveProperties) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java index d7fae37a4848f..0c507cfd6835e 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java @@ -13,15 +13,21 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.hive.authentication.NoHdfsAuthentication; import com.facebook.presto.hive.orc.DwrfPageSourceFactory; import com.facebook.presto.hive.orc.OrcPageSourceFactory; import com.facebook.presto.hive.parquet.ParquetPageSourceFactory; import com.facebook.presto.hive.parquet.ParquetRecordCursorProvider; import com.facebook.presto.hive.rcfile.RcFilePageSourceFactory; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; @@ -40,16 +46,21 @@ private HiveTestUtils() new HiveSessionProperties(new HiveClientConfig()).getSessionProperties()); public static final TypeRegistry TYPE_MANAGER = new TypeRegistry(); + static { + // associate TYPE_MANAGER with a function registry + new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig()); + } public static final HdfsEnvironment HDFS_ENVIRONMENT = createTestHdfsEnvironment(new HiveClientConfig()); public static Set getDefaultHiveDataStreamFactories(HiveClientConfig hiveClientConfig) { + FileFormatDataSourceStats stats = new FileFormatDataSourceStats(); HdfsEnvironment testHdfsEnvironment = createTestHdfsEnvironment(hiveClientConfig); return ImmutableSet.builder() - .add(new RcFilePageSourceFactory(TYPE_MANAGER, testHdfsEnvironment)) - .add(new OrcPageSourceFactory(TYPE_MANAGER, hiveClientConfig, testHdfsEnvironment)) - .add(new DwrfPageSourceFactory(TYPE_MANAGER, testHdfsEnvironment)) + .add(new RcFilePageSourceFactory(TYPE_MANAGER, testHdfsEnvironment, stats)) + .add(new OrcPageSourceFactory(TYPE_MANAGER, hiveClientConfig, testHdfsEnvironment, stats)) + .add(new DwrfPageSourceFactory(TYPE_MANAGER, testHdfsEnvironment, stats)) .add(new ParquetPageSourceFactory(TYPE_MANAGER, hiveClientConfig, testHdfsEnvironment)) .build(); } @@ -59,8 +70,6 @@ public static Set getDefaultHiveRecordCursorProvider(H HdfsEnvironment testHdfsEnvironment = createTestHdfsEnvironment(hiveClientConfig); return ImmutableSet.builder() .add(new ParquetRecordCursorProvider(hiveClientConfig, testHdfsEnvironment)) - .add(new ColumnarTextHiveRecordCursorProvider(testHdfsEnvironment)) - .add(new ColumnarBinaryHiveRecordCursorProvider(testHdfsEnvironment)) .add(new GenericHiveRecordCursorProvider(testHdfsEnvironment)) .build(); } @@ -69,7 +78,7 @@ public static Set getDefaultHiveFileWriterFactories(HiveC { HdfsEnvironment testHdfsEnvironment = createTestHdfsEnvironment(hiveClientConfig); return ImmutableSet.builder() - .add(new RcFileFileWriterFactory(testHdfsEnvironment, TYPE_MANAGER, new NodeVersion("test_version"), hiveClientConfig)) + .add(new RcFileFileWriterFactory(testHdfsEnvironment, TYPE_MANAGER, new NodeVersion("test_version"), hiveClientConfig, new FileFormatDataSourceStats())) .build(); } @@ -92,4 +101,11 @@ public static HdfsEnvironment createTestHdfsEnvironment(HiveClientConfig hiveCon HdfsConfiguration hdfsConfig = new HiveHdfsConfiguration(new HdfsConfigurationUpdater(hiveConfig, s3Config)); return new HdfsEnvironment(hdfsConfig, hiveConfig, new NoHdfsAuthentication()); } + + public static MapType mapType(Type keyType, Type valueType) + { + return (MapType) TYPE_MANAGER.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java index 19ff0a92c4bda..1cbc0987eb51b 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.hive.HiveClientConfig.HdfsAuthenticationType; +import com.facebook.presto.hive.HiveClientConfig.HiveMetastoreAuthenticationType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.net.HostAndPort; @@ -77,16 +79,12 @@ public void testDefaults() .setOrcMaxMergeDistance(new DataSize(1, Unit.MEGABYTE)) .setOrcMaxBufferSize(new DataSize(8, Unit.MEGABYTE)) .setOrcStreamBufferSize(new DataSize(8, Unit.MEGABYTE)) - .setRcfileOptimizedReaderEnabled(true) - .setRcfileOptimizedWriterEnabled(false) - .setHiveMetastoreAuthenticationType(HiveClientConfig.HiveMetastoreAuthenticationType.NONE) - .setHiveMetastoreServicePrincipal(null) - .setHiveMetastoreClientPrincipal(null) - .setHiveMetastoreClientKeytab(null) - .setHdfsAuthenticationType(HiveClientConfig.HdfsAuthenticationType.NONE) + .setOrcMaxReadBlockSize(new DataSize(16, Unit.MEGABYTE)) + .setRcfileOptimizedWriterEnabled(true) + .setRcfileWriterValidate(false) + .setHiveMetastoreAuthenticationType(HiveMetastoreAuthenticationType.NONE) + .setHdfsAuthenticationType(HdfsAuthenticationType.NONE) .setHdfsImpersonationEnabled(false) - .setHdfsPrestoPrincipal(null) - .setHdfsPrestoKeytab(null) .setSkipDeletionForAlter(false) .setBucketExecutionEnabled(true) .setBucketWritingEnabled(true) @@ -141,16 +139,12 @@ public void testExplicitPropertyMappings() .put("hive.orc.max-merge-distance", "22kB") .put("hive.orc.max-buffer-size", "44kB") .put("hive.orc.stream-buffer-size", "55kB") - .put("hive.rcfile-optimized-reader.enabled", "false") - .put("hive.rcfile-optimized-writer.enabled", "true") + .put("hive.orc.max-read-block-size", "66kB") + .put("hive.rcfile-optimized-writer.enabled", "false") + .put("hive.rcfile.writer.validate", "true") .put("hive.metastore.authentication.type", "KERBEROS") - .put("hive.metastore.service.principal", "hive/_HOST@EXAMPLE.COM") - .put("hive.metastore.client.principal", "metastore@EXAMPLE.COM") - .put("hive.metastore.client.keytab", "/tmp/metastore.keytab") .put("hive.hdfs.authentication.type", "KERBEROS") .put("hive.hdfs.impersonation.enabled", "true") - .put("hive.hdfs.presto.principal", "presto@EXAMPLE.COM") - .put("hive.hdfs.presto.keytab", "/tmp/presto.keytab") .put("hive.skip-deletion-for-alter", "true") .put("hive.bucket-execution", "false") .put("hive.bucket-writing", "false") @@ -202,16 +196,12 @@ public void testExplicitPropertyMappings() .setOrcMaxMergeDistance(new DataSize(22, Unit.KILOBYTE)) .setOrcMaxBufferSize(new DataSize(44, Unit.KILOBYTE)) .setOrcStreamBufferSize(new DataSize(55, Unit.KILOBYTE)) - .setRcfileOptimizedReaderEnabled(false) - .setRcfileOptimizedWriterEnabled(true) - .setHiveMetastoreAuthenticationType(HiveClientConfig.HiveMetastoreAuthenticationType.KERBEROS) - .setHiveMetastoreServicePrincipal("hive/_HOST@EXAMPLE.COM") - .setHiveMetastoreClientPrincipal("metastore@EXAMPLE.COM") - .setHiveMetastoreClientKeytab("/tmp/metastore.keytab") - .setHdfsAuthenticationType(HiveClientConfig.HdfsAuthenticationType.KERBEROS) + .setOrcMaxReadBlockSize(new DataSize(66, Unit.KILOBYTE)) + .setRcfileOptimizedWriterEnabled(false) + .setRcfileWriterValidate(true) + .setHiveMetastoreAuthenticationType(HiveMetastoreAuthenticationType.KERBEROS) + .setHdfsAuthenticationType(HdfsAuthenticationType.KERBEROS) .setHdfsImpersonationEnabled(true) - .setHdfsPrestoPrincipal("presto@EXAMPLE.COM") - .setHdfsPrestoKeytab("/tmp/presto.keytab") .setSkipDeletionForAlter(true) .setBucketExecutionEnabled(false) .setBucketWritingEnabled(false) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java index 3d9ba3e25fac9..2d28aa0345a9d 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java @@ -13,11 +13,13 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.Session; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.tests.AbstractTestDistributedQueries; import org.testng.annotations.Test; import static com.facebook.presto.hive.HiveQueryRunner.createQueryRunner; +import static com.facebook.presto.hive.HiveSessionProperties.RCFILE_OPTIMIZED_WRITER_ENABLED; import static com.facebook.presto.spi.type.CharType.createCharType; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.assertions.Assert.assertEquals; @@ -61,4 +63,65 @@ public void testOrderByChar() assertEquals(actual, expected); } + + /** + * Tests correctness of comparison of char(x) and varchar pushed down to a table scan as a TupleDomain + */ + @Test + public void testPredicatePushDownToTableScan() + throws Exception + { + // Test not specific to Hive, but needs a connector supporting table creation + + assertUpdate("CREATE TABLE test_table_with_char (a char(20))"); + try { + assertUpdate("INSERT INTO test_table_with_char (a) VALUES" + + "(cast('aaa' as char(20)))," + + "(cast('bbb' as char(20)))," + + "(cast('bbc' as char(20)))," + + "(cast('bbd' as char(20)))", 4); + + assertQuery( + "SELECT a, a <= 'bbc' FROM test_table_with_char", + "VALUES (cast('aaa' as char(20)), true), " + + "(cast('bbb' as char(20)), true), " + + "(cast('bbc' as char(20)), false), " + + "(cast('bbd' as char(20)), false)"); + + assertQuery( + "SELECT a FROM test_table_with_char WHERE a <= 'bbc'", + "VALUES cast('aaa' as char(20)), " + + "cast('bbb' as char(20))"); + } + finally { + assertUpdate("DROP TABLE test_table_with_char"); + } + } + + @Test + public void testRcTextCharDecoding() + throws Exception + { + testRcTextCharDecoding(false); + testRcTextCharDecoding(true); + } + + private void testRcTextCharDecoding(boolean rcFileOptimizedWriterEnabled) + throws Exception + { + String catalog = getSession().getCatalog().get(); + Session session = Session.builder(getSession()) + .setCatalogSessionProperty(catalog, RCFILE_OPTIMIZED_WRITER_ENABLED, Boolean.toString(rcFileOptimizedWriterEnabled)) + .build(); + + assertUpdate(session, "CREATE TABLE test_table_with_char_rc WITH (format = 'RCTEXT') AS SELECT CAST('khaki' AS CHAR(7)) char_column", 1); + try { + assertQuery(session, + "SELECT * FROM test_table_with_char_rc WHERE char_column = 'khaki '", + "VALUES (CAST('khaki' AS CHAR(7)))"); + } + finally { + assertUpdate(session, "DROP TABLE test_table_with_char_rc"); + } + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java index c1d573b5f0d50..4476381e335c9 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java @@ -24,14 +24,16 @@ import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.RecordPageSource; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.testing.TestingConnectorSession; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import io.airlift.compress.lzo.LzoCodec; +import io.airlift.compress.lzo.LzopCodec; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.common.type.HiveVarchar; @@ -97,6 +99,7 @@ public class TestHiveFileFormats extends AbstractTestHiveFileFormats { + private static final FileFormatDataSourceStats STATS = new FileFormatDataSourceStats(); private static TestingConnectorSession parquetCursorSession = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig().setParquetOptimizedReaderEnabled(false).setParquetPredicatePushdownEnabled(false)).getSessionProperties()); private static TestingConnectorSession parquetCursorPushdownSession = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig().setParquetOptimizedReaderEnabled(false).setParquetPredicatePushdownEnabled(true)).getSessionProperties()); private static TestingConnectorSession parquetPageSourceSession = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig().setParquetOptimizedReaderEnabled(true).setParquetPredicatePushdownEnabled(false)).getSessionProperties()); @@ -177,7 +180,6 @@ public void testRCText(int rowCount) assertThatFileFormat(RCTEXT) .withColumns(testColumns) .withRowsCount(rowCount) - .isReadableByRecordCursor(new ColumnarTextHiveRecordCursorProvider(HDFS_ENVIRONMENT)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); } @@ -185,14 +187,10 @@ public void testRCText(int rowCount) public void testRcTextPageSource(int rowCount) throws Exception { - TestingConnectorSession session = new TestingConnectorSession( - new HiveSessionProperties(new HiveClientConfig().setRcfileOptimizedReaderEnabled(true)).getSessionProperties()); - assertThatFileFormat(RCTEXT) .withColumns(TEST_COLUMNS) .withRowsCount(rowCount) - .withSession(session) - .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)); } @Test(dataProvider = "rowCount") @@ -205,16 +203,15 @@ public void testRcTextOptimizedWriter(int rowCount) .collect(toImmutableList()); TestingConnectorSession session = new TestingConnectorSession( - new HiveSessionProperties(new HiveClientConfig().setRcfileOptimizedWriterEnabled(true).setRcfileOptimizedReaderEnabled(true)).getSessionProperties()); + new HiveSessionProperties(new HiveClientConfig().setRcfileOptimizedWriterEnabled(true)).getSessionProperties()); assertThatFileFormat(RCTEXT) .withColumns(testColumns) .withRowsCount(rowCount) .withSession(session) - .withFileWriterFactory(new RcFileFileWriterFactory(HDFS_ENVIRONMENT, TYPE_MANAGER, new NodeVersion("test"), HIVE_STORAGE_TIME_ZONE)) - .isReadableByRecordCursor(new ColumnarTextHiveRecordCursorProvider(HDFS_ENVIRONMENT)) + .withFileWriterFactory(new RcFileFileWriterFactory(HDFS_ENVIRONMENT, TYPE_MANAGER, new NodeVersion("test"), HIVE_STORAGE_TIME_ZONE, STATS)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)) - .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)); } @Test(dataProvider = "rowCount") @@ -230,7 +227,6 @@ public void testRCBinary(int rowCount) assertThatFileFormat(RCBINARY) .withColumns(testColumns) .withRowsCount(rowCount) - .isReadableByRecordCursor(new ColumnarBinaryHiveRecordCursorProvider(HDFS_ENVIRONMENT)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); } @@ -243,14 +239,10 @@ public void testRcBinaryPageSource(int rowCount) .filter(testColumn -> !testColumn.getName().equals("t_empty_varchar")) .collect(toList()); - TestingConnectorSession session = new TestingConnectorSession( - new HiveSessionProperties(new HiveClientConfig().setRcfileOptimizedReaderEnabled(true)).getSessionProperties()); - assertThatFileFormat(RCBINARY) .withColumns(testColumns) .withRowsCount(rowCount) - .withSession(session) - .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)); } @Test(dataProvider = "rowCount") @@ -265,16 +257,15 @@ public void testRcBinaryOptimizedWriter(int rowCount) .collect(toList()); TestingConnectorSession session = new TestingConnectorSession( - new HiveSessionProperties(new HiveClientConfig().setRcfileOptimizedWriterEnabled(true).setRcfileOptimizedReaderEnabled(true)).getSessionProperties()); + new HiveSessionProperties(new HiveClientConfig().setRcfileOptimizedWriterEnabled(true)).getSessionProperties()); assertThatFileFormat(RCBINARY) .withColumns(testColumns) .withRowsCount(rowCount) .withSession(session) - .withFileWriterFactory(new RcFileFileWriterFactory(HDFS_ENVIRONMENT, TYPE_MANAGER, new NodeVersion("test"), HIVE_STORAGE_TIME_ZONE)) - .isReadableByRecordCursor(new ColumnarBinaryHiveRecordCursorProvider(HDFS_ENVIRONMENT)) + .withFileWriterFactory(new RcFileFileWriterFactory(HDFS_ENVIRONMENT, TYPE_MANAGER, new NodeVersion("test"), HIVE_STORAGE_TIME_ZONE, STATS)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)) - .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)); } @Test(dataProvider = "rowCount") @@ -284,7 +275,7 @@ public void testOrc(int rowCount) assertThatFileFormat(ORC) .withColumns(TEST_COLUMNS) .withRowsCount(rowCount) - .isReadableByPageSource(new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS)); } @Test(dataProvider = "rowCount") @@ -298,7 +289,7 @@ public void testOrcUseColumnNames(int rowCount) .withRowsCount(rowCount) .withReadColumns(Lists.reverse(TEST_COLUMNS)) .withSession(session) - .isReadableByPageSource(new OrcPageSourceFactory(TYPE_MANAGER, true, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new OrcPageSourceFactory(TYPE_MANAGER, true, HDFS_ENVIRONMENT, STATS)); } @Test(dataProvider = "rowCount") @@ -501,7 +492,7 @@ public void testDwrf(int rowCount) assertThatFileFormat(DWRF) .withColumns(testColumns) .withRowsCount(rowCount) - .isReadableByPageSource(new DwrfPageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new DwrfPageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)); } @Test @@ -514,19 +505,19 @@ public void testTruncateVarcharColumn() assertThatFileFormat(RCTEXT) .withWriteColumns(ImmutableList.of(writeColumn)) .withReadColumns(ImmutableList.of(readColumn)) - .isReadableByRecordCursor(new ColumnarTextHiveRecordCursorProvider(HDFS_ENVIRONMENT)) + .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); assertThatFileFormat(RCBINARY) .withWriteColumns(ImmutableList.of(writeColumn)) .withReadColumns(ImmutableList.of(readColumn)) - .isReadableByRecordCursor(new ColumnarBinaryHiveRecordCursorProvider(HDFS_ENVIRONMENT)) + .isReadableByPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); assertThatFileFormat(ORC) .withWriteColumns(ImmutableList.of(writeColumn)) .withReadColumns(ImmutableList.of(readColumn)) - .isReadableByPageSource(new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT)); + .isReadableByPageSource(new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS)); assertThatFileFormat(PARQUET) .withWriteColumns(ImmutableList.of(writeColumn)) @@ -580,17 +571,17 @@ public void testFailForLongVarcharPartitionColumn() assertThatFileFormat(RCTEXT) .withColumns(columns) - .isFailingForRecordCursor(new ColumnarTextHiveRecordCursorProvider(HDFS_ENVIRONMENT), expectedErrorCode, expectedMessage) + .isFailingForPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS), expectedErrorCode, expectedMessage) .isFailingForRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT), expectedErrorCode, expectedMessage); assertThatFileFormat(RCBINARY) .withColumns(columns) - .isFailingForRecordCursor(new ColumnarBinaryHiveRecordCursorProvider(HDFS_ENVIRONMENT), expectedErrorCode, expectedMessage) + .isFailingForPageSource(new RcFilePageSourceFactory(TYPE_MANAGER, HDFS_ENVIRONMENT, STATS), expectedErrorCode, expectedMessage) .isFailingForRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT), expectedErrorCode, expectedMessage); assertThatFileFormat(ORC) .withColumns(columns) - .isFailingForPageSource(new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT), expectedErrorCode, expectedMessage); + .isFailingForPageSource(new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS), expectedErrorCode, expectedMessage); assertThatFileFormat(PARQUET) .withColumns(columns) @@ -637,11 +628,13 @@ private void testCursorProvider(HiveRecordCursorProvider cursorProvider, .map(input -> new HivePartitionKey(input.getName(), HiveType.valueOf(input.getObjectInspector().getTypeName()), (String) input.getWriteValue())) .collect(toList()); + Configuration configuration = new Configuration(); + configuration.set("io.compression.codecs", LzoCodec.class.getName() + "," + LzopCodec.class.getName()); Optional pageSource = HivePageSourceProvider.createHivePageSource( ImmutableSet.of(cursorProvider), ImmutableSet.of(), "test", - new Configuration(), + configuration, SESSION, split.getPath(), OptionalInt.empty(), @@ -853,7 +846,18 @@ private void assertRead(Optional pageSourceFactory, Optio assertNotNull(session, "session must be specified"); assertTrue(rowsCount >= 0, "rowsCount must be greater than zero"); - File file = File.createTempFile("presto_test", formatName); + String compressionSuffix = compressionCodec.getCodec() + .map(codec -> { + try { + return codec.getConstructor().newInstance().getDefaultExtension(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }) + .orElse(""); + + File file = File.createTempFile("presto_test", formatName + compressionSuffix); file.delete(); try { FileSplit split; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java index da7b5f39df169..e64199d3ac7c1 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java @@ -1146,6 +1146,78 @@ private void testInsertPartitionedTableExistingPartition(Session session, HiveSt assertFalse(getQueryRunner().tableExists(session, tableName)); } + @Test + public void testPartitionPerScanLimit() + throws Exception + { + TestingHiveStorageFormat storageFormat = new TestingHiveStorageFormat(getSession(), HiveStorageFormat.DWRF); + testPartitionPerScanLimit(storageFormat.getSession(), storageFormat.getFormat()); + } + + public void testPartitionPerScanLimit(Session session, HiveStorageFormat storageFormat) + throws Exception + { + String tableName = "test_partition_per_scan_limit"; + + @Language("SQL") String createTable = "" + + "CREATE TABLE " + tableName + " " + + "(" + + " foo VARCHAR," + + " part BIGINT" + + ") " + + "WITH (" + + "format = '" + storageFormat + "', " + + "partitioned_by = ARRAY[ 'part' ]" + + ") "; + + assertUpdate(session, createTable); + + TableMetadata tableMetadata = getTableMetadata(catalog, TPCH_SCHEMA, tableName); + assertEquals(tableMetadata.getMetadata().getProperties().get(STORAGE_FORMAT_PROPERTY), storageFormat); + assertEquals(tableMetadata.getMetadata().getProperties().get(PARTITIONED_BY_PROPERTY), ImmutableList.of("part")); + + // insert 1200 partitions + for (int i = 0; i < 12; i++) { + int partStart = i * 100; + int partEnd = (i + 1) * 100 - 1; + + @Language("SQL") String insertPartitions = "" + + "INSERT INTO " + tableName + " " + + "SELECT 'bar' foo, part " + + "FROM UNNEST(SEQUENCE(" + partStart + ", " + partEnd + ")) AS TMP(part)"; + + assertUpdate(session, insertPartitions, 100); + } + + // verify can query 1000 partitions + assertQuery( + session, + "SELECT count(foo) FROM " + tableName + " WHERE part < 1000", + "SELECT 1000"); + + // verify the rest 200 partitions are successfully inserted + assertQuery( + session, + "SELECT count(foo) FROM " + tableName + " WHERE part >= 1000 AND part < 1200", + "SELECT 200"); + + // verify cannot query more than 1000 partitions + assertQueryFails( + session, + "SELECT * from " + tableName + " WHERE part < 1001", + format("Query over table 'tpch.%s' can potentially read more than 1000 partitions", tableName)); + + // verify cannot query all partitions + assertQueryFails( + session, + "SELECT * from " + tableName, + format("Query over table 'tpch.%s' can potentially read more than 1000 partitions", tableName)); + + assertUpdate(session, "DROP TABLE " + tableName); + + assertFalse(getQueryRunner().tableExists(session, tableName)); + } + @Test public void testInsertUnpartitionedTable() throws Exception @@ -1868,12 +1940,6 @@ private List getAllTestingHiveStorageFormat() for (HiveStorageFormat hiveStorageFormat : HiveStorageFormat.values()) { formats.add(new TestingHiveStorageFormat(session, hiveStorageFormat)); } - formats.add(new TestingHiveStorageFormat( - Session.builder(session).setCatalogSessionProperty(session.getCatalog().get(), "rcfile_optimized_reader_enabled", "true").build(), - HiveStorageFormat.RCBINARY)); - formats.add(new TestingHiveStorageFormat( - Session.builder(session).setCatalogSessionProperty(session.getCatalog().get(), "rcfile_optimized_reader_enabled", "true").build(), - HiveStorageFormat.RCTEXT)); formats.add(new TestingHiveStorageFormat( Session.builder(session).setCatalogSessionProperty(session.getCatalog().get(), "parquet_optimized_reader_enabled", "true").build(), HiveStorageFormat.PARQUET)); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java index b5737a597bc17..56c32279d4e59 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java @@ -144,7 +144,7 @@ public void run() hiveSplitSource.addToQueue(new TestSplit(33)); // wait for thread to get the split - ConnectorSplit split = splits.get(200, TimeUnit.MILLISECONDS); + ConnectorSplit split = splits.get(800, TimeUnit.MILLISECONDS); assertSame(split.getInfo(), 33); } finally { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestNumberParser.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestNumberParser.java deleted file mode 100644 index 14023b0165447..0000000000000 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestNumberParser.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import org.testng.annotations.Test; - -import static com.facebook.presto.hive.NumberParser.parseDouble; -import static com.facebook.presto.hive.NumberParser.parseLong; -import static java.nio.charset.StandardCharsets.US_ASCII; -import static org.testng.Assert.assertEquals; - -public class TestNumberParser -{ - @Test - public void testLong() - throws Exception - { - assertParseLong("1"); - assertParseLong("12"); - assertParseLong("123"); - assertParseLong("-1"); - assertParseLong("-12"); - assertParseLong("-123"); - assertParseLong("+1"); - assertParseLong("+12"); - assertParseLong("+123"); - assertParseLong("0"); - assertParseLong("-0"); - assertParseLong("+0"); - assertParseLong(Long.toString(Long.MAX_VALUE)); - assertParseLong(Long.toString(Long.MIN_VALUE)); - } - - @Test - public void testDouble() - throws Exception - { - assertParseDouble("123"); - assertParseDouble("123.0"); - assertParseDouble("123.456"); - assertParseDouble("123.456e5"); - assertParseDouble("123.456e-5"); - assertParseDouble("123e5"); - assertParseDouble("123e-5"); - assertParseDouble("0"); - assertParseDouble("0.0"); - assertParseDouble("0.456"); - assertParseDouble("-0"); - assertParseDouble("-0.0"); - assertParseDouble("-0.456"); - assertParseDouble("-123"); - assertParseDouble("-123.0"); - assertParseDouble("-123.456"); - assertParseDouble("-123.456e-5"); - assertParseDouble("-123e5"); - assertParseDouble("-123e-5"); - assertParseDouble("+123"); - assertParseDouble("+123.0"); - assertParseDouble("+123.456"); - assertParseDouble("+123.456e5"); - assertParseDouble("+123.456e-5"); - assertParseDouble("+123e5"); - assertParseDouble("+123e-5"); - assertParseDouble("+0"); - assertParseDouble("+0.0"); - assertParseDouble("+0.456"); - - assertParseDouble("NaN"); - assertParseDouble("-Infinity"); - assertParseDouble("Infinity"); - assertParseDouble("+Infinity"); - - assertParseDouble(Double.toString(Double.MAX_VALUE)); - assertParseDouble(Double.toString(-Double.MAX_VALUE)); - assertParseDouble(Double.toString(Double.MIN_VALUE)); - assertParseDouble(Double.toString(-Double.MIN_VALUE)); - } - - private static void assertParseLong(String string) - { - assertEquals(parseLong(string.getBytes(US_ASCII), 0, string.length()), Long.parseLong(string)); - - // verify we can parse using a non-zero offset - String padding = "9999"; - String padded = padding + string + padding; - assertEquals(parseLong(padded.getBytes(US_ASCII), padding.length(), string.length()), Long.parseLong(string)); - } - - private static void assertParseDouble(String string) - { - assertEquals(parseDouble(string.getBytes(US_ASCII), 0, string.length()), Double.parseDouble(string)); - - // verify we can parse using a non-zero offset - String padding = "9999"; - String padded = padding + string + padding; - assertEquals(parseDouble(padded.getBytes(US_ASCII), padding.length(), string.length()), Double.parseDouble(string)); - } -} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java index aaeb869a7e5ac..6afdd67ae0ee8 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java @@ -25,6 +25,7 @@ import com.facebook.presto.operator.project.PageProcessor; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; @@ -33,6 +34,7 @@ import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.relational.RowExpression; +import com.facebook.presto.testing.TestingConnectorSession; import com.facebook.presto.testing.TestingSplit; import com.facebook.presto.testing.TestingTransactionHandle; import com.google.common.base.Joiner; @@ -41,6 +43,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; +import io.airlift.stats.Distribution; +import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -67,6 +71,7 @@ import org.joda.time.DateTimeZone; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.io.File; @@ -74,6 +79,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -89,6 +95,7 @@ import static com.facebook.presto.hive.HiveTestUtils.SESSION; import static com.facebook.presto.hive.HiveTestUtils.TYPE_MANAGER; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.orc.OrcReader.MAX_BATCH_SIZE; import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -98,6 +105,7 @@ import static com.google.common.collect.Iterables.transform; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertBetweenInclusive; +import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.stream.Collectors.toList; @@ -132,6 +140,12 @@ public class TestOrcPageSourceMemoryTracking private File tempFile; private TestPreparer testPreparer; + @DataProvider(name = "rowCount") + public static Object[][] rowCount() + { + return new Object[][] { { 50_000 }, { 10_000 }, { 5_000 } }; + } + @BeforeClass public void setUp() throws Exception @@ -155,7 +169,8 @@ public void testPageSource() // Numbers used in assertions in this test may change when implementation is modified, // feel free to change them if they break in the future - ConnectorPageSource pageSource = testPreparer.newPageSource(); + FileFormatDataSourceStats stats = new FileFormatDataSourceStats(); + ConnectorPageSource pageSource = testPreparer.newPageSource(stats); assertEquals(pageSource.getSystemMemoryUsage(), 0); @@ -221,6 +236,64 @@ public void testPageSource() assertTrue(pageSource.isFinished()); assertEquals(pageSource.getSystemMemoryUsage(), 0); pageSource.close(); + assertEquals((int) stats.getLoadedBlockBytes().getAllTime().getCount(), 50); + } + + @Test(dataProvider = "rowCount") + public void testMaxReadBytes(int rowCount) + throws Exception + { + int maxReadBytes = 1_000; + HiveClientConfig config = new HiveClientConfig(); + config.setOrcMaxReadBlockSize(new DataSize(maxReadBytes, BYTE)); + ConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(config).getSessionProperties()); + FileFormatDataSourceStats stats = new FileFormatDataSourceStats(); + + // Build a table where every row gets larger, so we can test that the "batchSize" reduces + int numColumns = 5; + int step = 250; + ImmutableList.Builder columnBuilder = ImmutableList.builder() + .add(new TestColumn("p_empty_string", javaStringObjectInspector, () -> "", true)); + GrowingTestColumn[] dataColumns = new GrowingTestColumn[numColumns]; + for (int i = 0; i < numColumns; i++) { + dataColumns[i] = new GrowingTestColumn("p_string", javaStringObjectInspector, () -> Long.toHexString(random.nextLong()), false, step * (i + 1)); + columnBuilder.add(dataColumns[i]); + } + List testColumns = columnBuilder.build(); + File tempFile = File.createTempFile("presto_test_orc_page_source_max_read_bytes", "orc"); + tempFile.delete(); + + TestPreparer testPreparer = new TestPreparer(tempFile.getAbsolutePath(), testColumns, rowCount, rowCount); + ConnectorPageSource pageSource = testPreparer.newPageSource(stats, session); + + try { + int positionCount = 0; + while (true) { + Page page = pageSource.getNextPage(); + if (pageSource.isFinished()) { + break; + } + assertNotNull(page); + page.assureLoaded(); + positionCount += page.getPositionCount(); + // assert upper bound is tight + // ignore the first MAX_BATCH_SIZE rows given the sizes are set when loading the blocks + if (positionCount > MAX_BATCH_SIZE) { + // either the block is bounded by maxReadBytes or we just load one single large block + // an error margin MAX_BATCH_SIZE / step is needed given the block sizes are increasing + assertTrue(page.getSizeInBytes() < maxReadBytes * (MAX_BATCH_SIZE / step) || 1 == page.getPositionCount()); + } + } + + // verify the stats are correctly recorded + Distribution distribution = stats.getMaxCombinedBytesPerRow().getAllTime(); + assertEquals((int) distribution.getCount(), 1); + assertEquals((int) distribution.getMax(), Arrays.stream(dataColumns).mapToInt(GrowingTestColumn::getMaxSize).sum()); + pageSource.close(); + } + finally { + tempFile.delete(); + } } @Test @@ -321,6 +394,12 @@ private class TestPreparer public TestPreparer(String tempFilePath) throws Exception + { + this(tempFilePath, testColumns, NUM_ROWS, STRIPE_ROWS); + } + + public TestPreparer(String tempFilePath, List testColumns, int numRows, int stripeRows) + throws Exception { OrcSerde serde = new OrcSerde(); schema = new Properties(); @@ -357,18 +436,28 @@ public TestPreparer(String tempFilePath) columns = columnsBuilder.build(); types = typesBuilder.build(); - fileSplit = createTestFile(tempFilePath, new OrcOutputFormat(), serde, null, testColumns, NUM_ROWS); + fileSplit = createTestFile(tempFilePath, new OrcOutputFormat(), serde, null, testColumns, numRows, stripeRows); } public ConnectorPageSource newPageSource() { - OrcPageSourceFactory orcPageSourceFactory = new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT); + return newPageSource(new FileFormatDataSourceStats(), SESSION); + } + + public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats) + { + return newPageSource(stats, SESSION); + } + + public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats, ConnectorSession session) + { + OrcPageSourceFactory orcPageSourceFactory = new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT, stats); return HivePageSourceProvider.createHivePageSource( ImmutableSet.of(), ImmutableSet.of(orcPageSourceFactory), "test", new Configuration(), - SESSION, + session, fileSplit.getPath(), OptionalInt.empty(), fileSplit.getStart(), @@ -435,7 +524,8 @@ public static FileSplit createTestFile(String filePath, @SuppressWarnings("deprecation") SerDe serDe, String compressionCodec, List testColumns, - int numRows) + int numRows, + int stripeRows) throws Exception { // filter out partition keys, which are not written to the file @@ -475,7 +565,7 @@ public static FileSplit createTestFile(String filePath, Writable record = serDe.serialize(row, objectInspector); recordWriter.write(record); - if (rowNumber % STRIPE_ROWS == STRIPE_ROWS - 1) { + if (rowNumber % stripeRows == stripeRows - 1) { flushStripe(recordWriter); } } @@ -539,7 +629,7 @@ private static Constructor getOrcWriterConstructor() } } - public static final class TestColumn + public static class TestColumn { private final String name; private final ObjectInspector objectInspector; @@ -590,4 +680,41 @@ public String toString() return sb.toString(); } } + + public static final class GrowingTestColumn + extends TestColumn + { + private final Supplier writeValue; + private int counter; + private int step; + private int maxSize; + + public GrowingTestColumn(String name, ObjectInspector objectInspector, Supplier writeValue, boolean partitionKey, int step) + { + super(name, objectInspector, writeValue, partitionKey); + this.writeValue = writeValue; + this.counter = step; + this.step = step; + } + + @Override + public Object getWriteValue() + { + StringBuilder builder = new StringBuilder(); + String source = writeValue.get(); + for (int i = 0; i < counter / step; i++) { + builder.append(source); + } + counter++; + if (builder.length() > maxSize) { + maxSize = builder.length(); + } + return builder.toString(); + } + + public int getMaxSize() + { + return maxSize; + } + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/authentication/TestHdfsKerberosConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/authentication/TestHdfsKerberosConfig.java new file mode 100644 index 0000000000000..74a7e4ccb5497 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/authentication/TestHdfsKerberosConfig.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.authentication; + +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import org.testng.annotations.Test; + +import java.util.Map; + +public class TestHdfsKerberosConfig +{ + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("hive.hdfs.presto.principal", "presto@EXAMPLE.COM") + .put("hive.hdfs.presto.keytab", "/tmp/presto.keytab") + .build(); + + HdfsKerberosConfig expected = new HdfsKerberosConfig() + .setHdfsPrestoPrincipal("presto@EXAMPLE.COM") + .setHdfsPrestoKeytab("/tmp/presto.keytab"); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/authentication/TestMetastoreKerberosConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/authentication/TestMetastoreKerberosConfig.java new file mode 100644 index 0000000000000..460fae5534f31 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/authentication/TestMetastoreKerberosConfig.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.authentication; + +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import org.testng.annotations.Test; + +import java.util.Map; + +public class TestMetastoreKerberosConfig +{ + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("hive.metastore.service.principal", "hive/_HOST@EXAMPLE.COM") + .put("hive.metastore.client.principal", "metastore@EXAMPLE.COM") + .put("hive.metastore.client.keytab", "/tmp/metastore.keytab") + .build(); + + MetastoreKerberosConfig expected = new MetastoreKerberosConfig() + .setHiveMetastoreServicePrincipal("hive/_HOST@EXAMPLE.COM") + .setHiveMetastoreClientPrincipal("metastore@EXAMPLE.COM") + .setHiveMetastoreClientKeytab("/tmp/metastore.keytab"); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java index 2a7df8715e252..703a6a376ca77 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.hive.benchmark; -import com.facebook.presto.hive.ColumnarBinaryHiveRecordCursorProvider; -import com.facebook.presto.hive.ColumnarTextHiveRecordCursorProvider; +import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.GenericHiveRecordCursorProvider; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveColumnHandle; @@ -78,7 +77,7 @@ public enum FileFormat @Override public ConnectorPageSource createFileFormatReader(ConnectorSession session, HdfsEnvironment hdfsEnvironment, File targetFile, List columnNames, List columnTypes) { - HivePageSourceFactory pageSourceFactory = new RcFilePageSourceFactory(TYPE_MANAGER, hdfsEnvironment); + HivePageSourceFactory pageSourceFactory = new RcFilePageSourceFactory(TYPE_MANAGER, hdfsEnvironment, new FileFormatDataSourceStats()); return createPageSource(pageSourceFactory, session, targetFile, columnNames, columnTypes, HiveStorageFormat.RCBINARY); } @@ -103,7 +102,7 @@ public FormatWriter createFileFormatWriter( @Override public ConnectorPageSource createFileFormatReader(ConnectorSession session, HdfsEnvironment hdfsEnvironment, File targetFile, List columnNames, List columnTypes) { - HivePageSourceFactory pageSourceFactory = new RcFilePageSourceFactory(TYPE_MANAGER, hdfsEnvironment); + HivePageSourceFactory pageSourceFactory = new RcFilePageSourceFactory(TYPE_MANAGER, hdfsEnvironment, new FileFormatDataSourceStats()); return createPageSource(pageSourceFactory, session, targetFile, columnNames, columnTypes, HiveStorageFormat.RCTEXT); } @@ -128,7 +127,7 @@ public FormatWriter createFileFormatWriter( @Override public ConnectorPageSource createFileFormatReader(ConnectorSession session, HdfsEnvironment hdfsEnvironment, File targetFile, List columnNames, List columnTypes) { - HivePageSourceFactory pageSourceFactory = new OrcPageSourceFactory(TYPE_MANAGER, false, hdfsEnvironment); + HivePageSourceFactory pageSourceFactory = new OrcPageSourceFactory(TYPE_MANAGER, false, hdfsEnvironment, new FileFormatDataSourceStats()); return createPageSource(pageSourceFactory, session, targetFile, columnNames, columnTypes, HiveStorageFormat.ORC); } @@ -149,7 +148,7 @@ public FormatWriter createFileFormatWriter( @Override public ConnectorPageSource createFileFormatReader(ConnectorSession session, HdfsEnvironment hdfsEnvironment, File targetFile, List columnNames, List columnTypes) { - HivePageSourceFactory pageSourceFactory = new DwrfPageSourceFactory(TYPE_MANAGER, hdfsEnvironment); + HivePageSourceFactory pageSourceFactory = new DwrfPageSourceFactory(TYPE_MANAGER, hdfsEnvironment, new FileFormatDataSourceStats()); return createPageSource(pageSourceFactory, session, targetFile, columnNames, columnTypes, HiveStorageFormat.DWRF); } @@ -197,7 +196,7 @@ public FormatWriter createFileFormatWriter( @Override public ConnectorPageSource createFileFormatReader(ConnectorSession session, HdfsEnvironment hdfsEnvironment, File targetFile, List columnNames, List columnTypes) { - HiveRecordCursorProvider cursorProvider = new ColumnarBinaryHiveRecordCursorProvider(hdfsEnvironment); + HiveRecordCursorProvider cursorProvider = new GenericHiveRecordCursorProvider(hdfsEnvironment); return createPageSource(cursorProvider, session, targetFile, columnNames, columnTypes, HiveStorageFormat.RCBINARY); } @@ -218,7 +217,7 @@ public FormatWriter createFileFormatWriter( @Override public ConnectorPageSource createFileFormatReader(ConnectorSession session, HdfsEnvironment hdfsEnvironment, File targetFile, List columnNames, List columnTypes) { - HiveRecordCursorProvider cursorProvider = new ColumnarTextHiveRecordCursorProvider(hdfsEnvironment); + HiveRecordCursorProvider cursorProvider = new GenericHiveRecordCursorProvider(hdfsEnvironment); return createPageSource(cursorProvider, session, targetFile, columnNames, columnTypes, HiveStorageFormat.RCTEXT); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java index 46df5cb2fed82..fb32e0b9aa150 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java @@ -23,10 +23,9 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.testing.TestingConnectorSession; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import io.airlift.tpch.OrderColumn; @@ -62,6 +61,7 @@ import java.util.concurrent.TimeUnit; import static com.facebook.presto.hive.HiveTestUtils.createTestHdfsEnvironment; +import static com.facebook.presto.hive.HiveTestUtils.mapType; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; @@ -91,7 +91,6 @@ public class HiveFileFormatBenchmark @SuppressWarnings("deprecation") private static final HiveClientConfig CONFIG = new HiveClientConfig() - .setRcfileOptimizedReaderEnabled(true) .setParquetOptimizedReaderEnabled(true); private static final ConnectorSession SESSION = new TestingConnectorSession(new HiveSessionProperties(CONFIG) @@ -282,7 +281,7 @@ public TestData createTestData(FileFormat format) @Override public TestData createTestData(FileFormat format) { - Type type = new MapType(createUnboundedVarcharType(), DOUBLE); + Type type = mapType(createUnboundedVarcharType(), DOUBLE); Random random = new Random(1234); PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type)); @@ -321,7 +320,7 @@ public TestData createTestData(FileFormat format) @Override public TestData createTestData(FileFormat format) { - Type type = new MapType(createUnboundedVarcharType(), DOUBLE); + Type type = mapType(createUnboundedVarcharType(), DOUBLE); Random random = new Random(1234); PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type)); @@ -356,7 +355,7 @@ public TestData createTestData(FileFormat format) @Override public TestData createTestData(FileFormat format) { - Type type = new MapType(INTEGER, DOUBLE); + Type type = mapType(INTEGER, DOUBLE); Random random = new Random(1234); PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type)); @@ -395,7 +394,7 @@ public TestData createTestData(FileFormat format) @Override public TestData createTestData(FileFormat format) { - Type type = new MapType(INTEGER, DOUBLE); + Type type = mapType(INTEGER, DOUBLE); Random random = new Random(1234); PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type)); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/InMemoryHiveMetastore.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/InMemoryHiveMetastore.java index cf14e31e7a685..99c14c70e631a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/InMemoryHiveMetastore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/InMemoryHiveMetastore.java @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.common.FileUtils; import org.apache.hadoop.hive.metastore.TableType; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.Partition; @@ -79,6 +80,10 @@ public class InMemoryHiveMetastore @GuardedBy("this") private final Map partitions = new HashMap<>(); @GuardedBy("this") + private final Map> columnStatistics = new HashMap<>(); + @GuardedBy("this") + private final Map> partitionColumnStatistics = new HashMap<>(); + @GuardedBy("this") private final Map> roleGrants = new HashMap<>(); @GuardedBy("this") private final Map> tablePrivileges = new HashMap<>(); @@ -431,6 +436,58 @@ public synchronized Optional
getTable(String databaseName, String tableNa return Optional.ofNullable(relations.get(schemaTableName)); } + @Override + public synchronized Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); + if (!columnStatistics.containsKey(schemaTableName)) { + return Optional.empty(); + } + + Map columnStatisticsMap = columnStatistics.get(schemaTableName); + return Optional.of(columnNames.stream() + .filter(columnStatisticsMap::containsKey) + .map(columnStatisticsMap::get) + .collect(toImmutableSet())); + } + + public synchronized void setColumnStatistics(String databaseName, String tableName, String columnName, ColumnStatisticsObj columnStatisticsObj) + { + checkArgument(columnStatisticsObj.getColName().equals(columnName), "columnName argument and columnStatisticsObj.getColName() must be the same"); + SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName); + columnStatistics.computeIfAbsent(schemaTableName, key -> new HashMap<>()).put(columnName, columnStatisticsObj); + } + + @Override + public synchronized Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames) + { + ImmutableMap.Builder> result = ImmutableMap.builder(); + for (String partitionName : partitionNames) { + PartitionName partitionKey = PartitionName.partition(databaseName, tableName, partitionName); + if (!partitionColumnStatistics.containsKey(partitionKey)) { + continue; + } + + Map columnStatistics = partitionColumnStatistics.get(partitionKey); + result.put( + partitionName, + columnNames.stream() + .filter(columnStatistics::containsKey) + .map(columnStatistics::get) + .collect(toImmutableSet())); + } + return Optional.of(result.build()); + } + + public synchronized void setPartitionColumnStatistics(String databaseName, String tableName, String partitionName, String columnName, ColumnStatisticsObj columnStatisticsObj) + { + checkArgument(columnStatisticsObj.getColName().equals(columnName), "columnName argument and columnStatisticsObj.getColName() must be the same"); + PartitionName partitionKey = PartitionName.partition(databaseName, tableName, partitionName); + partitionColumnStatistics + .computeIfAbsent(partitionKey, key -> new HashMap<>()) + .put(columnName, columnStatisticsObj); + } + @Override public synchronized Set getRoles(String user) { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/MockHiveMetastoreClient.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/MockHiveMetastoreClient.java index cf7ec82c8b20e..c087e6ce1bcc2 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/MockHiveMetastoreClient.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/MockHiveMetastoreClient.java @@ -20,6 +20,7 @@ import com.google.common.collect.Lists; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.Warehouse; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; @@ -37,6 +38,7 @@ import org.apache.thrift.TException; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; public class MockHiveMetastoreClient @@ -131,6 +133,20 @@ public Table getTable(String dbName, String tableName) TableType.MANAGED_TABLE.name()); } + @Override + public List getTableColumnStatistics(String databaseName, String tableName, List columnNames) + throws TException + { + throw new UnsupportedOperationException(); + } + + @Override + public Map> getPartitionColumnStatistics(String databaseName, String tableName, List columnNames, List partitionValues) + throws TException + { + throw new UnsupportedOperationException(); + } + @Override public List getTableNamesByFilter(String databaseName, String filter) { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java index 7b808de9cf168..f1a2503f50a6a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestingHiveMetastore.java @@ -494,6 +494,18 @@ public synchronized Optional
getTable(String databaseName, String tableNa return Optional.ofNullable(relations.get(schemaTableName)); } + @Override + public Optional> getTableColumnStatistics(String databaseName, String tableName, Set columnNames) + { + return Optional.of(ImmutableMap.of()); + } + + @Override + public Optional>> getPartitionColumnStatistics(String databaseName, String tableName, Set partitionNames, Set columnNames) + { + return Optional.of(ImmutableMap.of()); + } + private synchronized Table getRequiredTable(SchemaTableName tableName) { Table oldTable = relations.get(tableName); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java index 910a5f34eb321..401bc62646109 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSerDeUtils.java @@ -18,9 +18,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.reflect.TypeToken; @@ -42,6 +41,7 @@ import java.util.Optional; import java.util.TreeMap; +import static com.facebook.presto.hive.HiveTestUtils.mapType; import static com.facebook.presto.hive.util.SerDeUtils.getBlockObject; import static com.facebook.presto.hive.util.SerDeUtils.serializeObject; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -201,7 +201,7 @@ public void testMapBlock() holder.map.put("fifteen", new InnerStruct(16, 17L)); com.facebook.presto.spi.type.Type rowType = new RowType(ImmutableList.of(INTEGER, BIGINT), Optional.empty()); - com.facebook.presto.spi.type.Type mapOfVarcharRowType = new RowType(ImmutableList.of(new MapType(createUnboundedVarcharType(), rowType)), Optional.empty()); + com.facebook.presto.spi.type.Type mapOfVarcharRowType = new RowType(ImmutableList.of(mapType(createUnboundedVarcharType(), rowType)), Optional.empty()); Block actual = toBinaryBlock(mapOfVarcharRowType, holder, getInspector(MapHolder.class)); BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(createUnboundedVarcharType(), rowType), new BlockBuilderStatus(), 1024); @@ -209,7 +209,7 @@ public void testMapBlock() rowType.writeObject(blockBuilder, rowBlockOf(ImmutableList.of(INTEGER, BIGINT), 16, 17L)); createUnboundedVarcharType().writeString(blockBuilder, "twelve"); rowType.writeObject(blockBuilder, rowBlockOf(ImmutableList.of(INTEGER, BIGINT), 13, 14L)); - Block expected = rowBlockOf(ImmutableList.of(new MapType(createUnboundedVarcharType(), rowType)), blockBuilder); + Block expected = rowBlockOf(ImmutableList.of(mapType(createUnboundedVarcharType(), rowType)), blockBuilder); assertBlockEquals(actual, expected); } @@ -248,7 +248,7 @@ public void testStructBlock() com.facebook.presto.spi.type.Type innerRowType = new RowType(ImmutableList.of(INTEGER, BIGINT), Optional.empty()); com.facebook.presto.spi.type.Type arrayOfInnerRowType = new ArrayType(innerRowType); - com.facebook.presto.spi.type.Type mapOfInnerRowType = new MapType(createUnboundedVarcharType(), innerRowType); + com.facebook.presto.spi.type.Type mapOfInnerRowType = mapType(createUnboundedVarcharType(), innerRowType); List outerRowParameterTypes = ImmutableList.of(TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, createUnboundedVarcharType(), createUnboundedVarcharType(), arrayOfInnerRowType, mapOfInnerRowType, innerRowType); com.facebook.presto.spi.type.Type outerRowType = new RowType(outerRowParameterTypes, Optional.empty()); @@ -290,7 +290,7 @@ public void testReuse() Type type = new TypeToken>() {}.getType(); ObjectInspector inspector = getInspector(type); - Block actual = getBlockObject(new MapType(createUnboundedVarcharType(), BIGINT), ImmutableMap.of(value, 0L), inspector); + Block actual = getBlockObject(mapType(createUnboundedVarcharType(), BIGINT), ImmutableMap.of(value, 0L), inspector); Block expected = mapBlockOf(createUnboundedVarcharType(), BIGINT, "bye", 0L); assertBlockEquals(actual, expected); diff --git a/presto-jdbc/pom.xml b/presto-jdbc/pom.xml index 8a8b27cffbf39..19978952d2083 100644 --- a/presto-jdbc/pom.xml +++ b/presto-jdbc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-jdbc @@ -20,6 +20,12 @@ com.facebook.presto presto-client + + + javax.validation + validation-api + + @@ -28,39 +34,8 @@ - io.airlift - http-client - - - - io.airlift - configuration - - - - io.airlift - trace-token - - - - com.google.inject - guice - - - com.google.inject.extensions - guice-multibindings - - - - org.weakref - jmxutils - - - - org.eclipse.jetty - jetty-servlet - - + com.squareup.okhttp3 + okhttp @@ -72,10 +47,6 @@ io.airlift json - - javax.inject - javax.inject - com.google.inject guice @@ -97,11 +68,6 @@ guava - - com.google.code.findbugs - annotations - - com.facebook.presto @@ -144,6 +110,12 @@ concurrent test + + + com.squareup.okhttp3 + mockwebserver + test + @@ -199,12 +171,12 @@ ${shadeBase}.joda.time - org.eclipse.jetty - ${shadeBase}.jetty + okhttp3 + ${shadeBase}.okhttp3 - org.HdrHistogram - ${shadeBase}.HdrHistogram + okio + ${shadeBase}.okio @@ -212,12 +184,8 @@ *:* META-INF/maven/** - META-INF/*.xml - META-INF/services/org.eclipse.** META-INF/services/com.fasterxml.** LICENSE - *.css - *.html @@ -226,10 +194,11 @@ ** + - javax.validation:validation-api + com.squareup.okhttp3:okhttp - ** + publicsuffixes.gz diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/AbstractConnectionProperty.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/AbstractConnectionProperty.java new file mode 100644 index 0000000000000..814a27e14f8d0 --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/AbstractConnectionProperty.java @@ -0,0 +1,166 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import java.io.File; +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.util.Optional; +import java.util.Properties; +import java.util.function.Predicate; + +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +abstract class AbstractConnectionProperty + implements ConnectionProperty +{ + private final String key; + private final Optional defaultValue; + private final Predicate isRequired; + private final Predicate isAllowed; + private final Converter converter; + + protected AbstractConnectionProperty( + String key, + Optional defaultValue, + Predicate isRequired, + Predicate isAllowed, + Converter converter) + { + this.key = requireNonNull(key, "key is null"); + this.defaultValue = requireNonNull(defaultValue, "defaultValue is null"); + this.isRequired = requireNonNull(isRequired, "isRequired is null"); + this.isAllowed = requireNonNull(isAllowed, "isAllowed is null"); + this.converter = requireNonNull(converter, "converter is null"); + } + + protected AbstractConnectionProperty( + String key, + Predicate required, + Predicate allowed, + Converter converter) + { + this(key, Optional.empty(), required, allowed, converter); + } + + @Override + public String getKey() + { + return key; + } + + @Override + public Optional getDefault() + { + return defaultValue; + } + + @Override + public DriverPropertyInfo getDriverPropertyInfo(Properties mergedProperties) + { + String currentValue = mergedProperties.getProperty(key); + DriverPropertyInfo result = new DriverPropertyInfo(key, currentValue); + result.required = isRequired.test(mergedProperties); + return result; + } + + @Override + public boolean isRequired(Properties properties) + { + return isRequired.test(properties); + } + + @Override + public boolean isAllowed(Properties properties) + { + return !properties.containsKey(key) || isAllowed.test(properties); + } + + @Override + public Optional getValue(Properties properties) + throws SQLException + { + String value = properties.getProperty(key); + if (value == null) { + if (isRequired(properties)) { + throw new SQLException(format("Connection property '%s' is required", key)); + } + return Optional.empty(); + } + if (value.isEmpty()) { + throw new SQLException(format("Connection property '%s' value is empty", key)); + } + + try { + return Optional.of(converter.convert(value)); + } + catch (RuntimeException e) { + throw new SQLException(format("Connection property '%s' value is invalid: %s", key, value), e); + } + } + + @Override + public void validate(Properties properties) + throws SQLException + { + if (!isAllowed(properties)) { + throw new SQLException(format("Connection property '%s' is not allowed", key)); + } + + getValue(properties); + } + + protected static final Predicate REQUIRED = properties -> true; + protected static final Predicate NOT_REQUIRED = properties -> false; + + protected static final Predicate ALLOWED = properties -> true; + + interface Converter + { + T convert(String value); + } + + protected static final Converter STRING_CONVERTER = value -> value; + protected static final Converter FILE_CONVERTER = File::new; + + protected static final Converter BOOLEAN_CONVERTER = value -> { + switch (value.toLowerCase(ENGLISH)) { + case "true": + return true; + case "false": + return false; + } + throw new IllegalArgumentException("value must be 'true' or 'false'"); + }; + + protected interface CheckedPredicate + { + boolean test(T t) + throws SQLException; + } + + protected static Predicate checkedPredicate(CheckedPredicate predicate) + { + return t -> { + try { + return predicate.test(t); + } + catch (SQLException e) { + return false; + } + }; + } +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java new file mode 100644 index 0000000000000..8ed6f7a63ecc6 --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java @@ -0,0 +1,227 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.net.HostAndPort; + +import java.io.File; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.function.Predicate; + +import static com.facebook.presto.jdbc.AbstractConnectionProperty.checkedPredicate; +import static java.util.Collections.unmodifiableMap; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; + +final class ConnectionProperties +{ + public static final ConnectionProperty USER = new User(); + public static final ConnectionProperty PASSWORD = new Password(); + public static final ConnectionProperty SOCKS_PROXY = new SocksProxy(); + public static final ConnectionProperty HTTP_PROXY = new HttpProxy(); + public static final ConnectionProperty SSL = new Ssl(); + public static final ConnectionProperty SSL_TRUST_STORE_PATH = new SslTrustStorePath(); + public static final ConnectionProperty SSL_TRUST_STORE_PASSWORD = new SslTrustStorePassword(); + public static final ConnectionProperty KERBEROS_REMOTE_SERICE_NAME = new KerberosRemoteServiceName(); + public static final ConnectionProperty KERBEROS_USE_CANONICAL_HOSTNAME = new KerberosUseCanonicalHostname(); + public static final ConnectionProperty KERBEROS_PRINCIPAL = new KerberosPrincipal(); + public static final ConnectionProperty KERBEROS_CONFIG_PATH = new KerberosConfigPath(); + public static final ConnectionProperty KERBEROS_KEYTAB_PATH = new KerberosKeytabPath(); + public static final ConnectionProperty KERBEROS_CREDENTIAL_CACHE_PATH = new KerberosCredentialCachePath(); + + private static final Set> ALL_PROPERTIES = ImmutableSet.>builder() + .add(USER) + .add(PASSWORD) + .add(SOCKS_PROXY) + .add(HTTP_PROXY) + .add(SSL) + .add(SSL_TRUST_STORE_PATH) + .add(SSL_TRUST_STORE_PASSWORD) + .add(KERBEROS_REMOTE_SERICE_NAME) + .add(KERBEROS_USE_CANONICAL_HOSTNAME) + .add(KERBEROS_PRINCIPAL) + .add(KERBEROS_CONFIG_PATH) + .add(KERBEROS_KEYTAB_PATH) + .add(KERBEROS_CREDENTIAL_CACHE_PATH) + .build(); + + private static final Map> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream() + .collect(toMap(ConnectionProperty::getKey, identity()))); + + private static final Map DEFAULTS; + + static { + ImmutableMap.Builder defaults = ImmutableMap.builder(); + for (ConnectionProperty property : ALL_PROPERTIES) { + property.getDefault().ifPresent(value -> defaults.put(property.getKey(), value)); + } + DEFAULTS = defaults.build(); + } + + private ConnectionProperties() {} + + public static ConnectionProperty forKey(String propertiesKey) + { + return KEY_LOOKUP.get(propertiesKey); + } + + public static Set> allProperties() + { + return ALL_PROPERTIES; + } + + public static Map getDefaults() + { + return DEFAULTS; + } + + private static class User + extends AbstractConnectionProperty + { + public User() + { + super("user", REQUIRED, ALLOWED, STRING_CONVERTER); + } + } + + private static class Password + extends AbstractConnectionProperty + { + public Password() + { + super("password", NOT_REQUIRED, ALLOWED, STRING_CONVERTER); + } + } + + private static class SocksProxy + extends AbstractConnectionProperty + { + private static final Predicate NO_HTTP_PROXY = + checkedPredicate(properties -> !HTTP_PROXY.getValue(properties).isPresent()); + + public SocksProxy() + { + super("socksProxy", NOT_REQUIRED, NO_HTTP_PROXY, HostAndPort::fromString); + } + } + + private static class HttpProxy + extends AbstractConnectionProperty + { + private static final Predicate NO_SOCKS_PROXY = + checkedPredicate(properties -> !SOCKS_PROXY.getValue(properties).isPresent()); + + public HttpProxy() + { + super("httpProxy", NOT_REQUIRED, NO_SOCKS_PROXY, HostAndPort::fromString); + } + } + + private static class Ssl + extends AbstractConnectionProperty + { + public Ssl() + { + super("SSL", Optional.of("false"), NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER); + } + } + + private static class SslTrustStorePath + extends AbstractConnectionProperty + { + private static final Predicate IF_SSL_ENABLED = + checkedPredicate(properties -> SSL.getValue(properties).orElse(false)); + + public SslTrustStorePath() + { + super("SSLTrustStorePath", NOT_REQUIRED, IF_SSL_ENABLED, STRING_CONVERTER); + } + } + + private static class SslTrustStorePassword + extends AbstractConnectionProperty + { + private static final Predicate IF_TRUST_STORE = + checkedPredicate(properties -> SSL_TRUST_STORE_PATH.getValue(properties).isPresent()); + + public SslTrustStorePassword() + { + super("SSLTrustStorePassword", NOT_REQUIRED, IF_TRUST_STORE, STRING_CONVERTER); + } + } + + private static class KerberosRemoteServiceName + extends AbstractConnectionProperty + { + public KerberosRemoteServiceName() + { + super("KerberosRemoteServiceName", NOT_REQUIRED, ALLOWED, STRING_CONVERTER); + } + } + + private static Predicate isKerberosEnabled() + { + return checkedPredicate(properties -> KERBEROS_REMOTE_SERICE_NAME.getValue(properties).isPresent()); + } + + private static class KerberosPrincipal + extends AbstractConnectionProperty + { + public KerberosPrincipal() + { + super("KerberosPrincipal", NOT_REQUIRED, isKerberosEnabled(), STRING_CONVERTER); + } + } + + private static class KerberosUseCanonicalHostname + extends AbstractConnectionProperty + { + public KerberosUseCanonicalHostname() + { + super("KerberosUseCanonicalHostname", Optional.of("true"), isKerberosEnabled(), ALLOWED, BOOLEAN_CONVERTER); + } + } + + private static class KerberosConfigPath + extends AbstractConnectionProperty + { + public KerberosConfigPath() + { + super("KerberosConfigPath", NOT_REQUIRED, isKerberosEnabled(), FILE_CONVERTER); + } + } + + private static class KerberosKeytabPath + extends AbstractConnectionProperty + { + public KerberosKeytabPath() + { + super("KerberosKeytabPath", NOT_REQUIRED, isKerberosEnabled(), FILE_CONVERTER); + } + } + + private static class KerberosCredentialCachePath + extends AbstractConnectionProperty + { + public KerberosCredentialCachePath() + { + super("KerberosCredentialCachePath", NOT_REQUIRED, isKerberosEnabled(), FILE_CONVERTER); + } + } +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperty.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperty.java new file mode 100644 index 0000000000000..ac2e90897803d --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperty.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.util.Optional; +import java.util.Properties; + +import static java.lang.String.format; + +interface ConnectionProperty +{ + String getKey(); + + Optional getDefault(); + + DriverPropertyInfo getDriverPropertyInfo(Properties properties); + + boolean isRequired(Properties properties); + + boolean isAllowed(Properties properties); + + Optional getValue(Properties properties) + throws SQLException; + + default T getRequiredValue(Properties properties) + throws SQLException + { + return getValue(properties).orElseThrow(() -> + new SQLException(format("Connection property '%s' is required", getKey()))); + } + + void validate(Properties properties) + throws SQLException; +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java index 5549ca5de473b..4e54ac4f3b397 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java @@ -38,6 +38,7 @@ import java.sql.Savepoint; import java.sql.Statement; import java.sql.Struct; +import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Properties; @@ -72,7 +73,7 @@ public class PrestoConnection private final AtomicReference transactionId = new AtomicReference<>(); private final QueryExecutor queryExecutor; - PrestoConnection(PrestoDriverUri uri, String user, QueryExecutor queryExecutor) + PrestoConnection(PrestoDriverUri uri, QueryExecutor queryExecutor) throws SQLException { requireNonNull(uri, "uri is null"); @@ -80,9 +81,10 @@ public class PrestoConnection this.httpUri = uri.getHttpUri(); this.schema.set(uri.getSchema()); this.catalog.set(uri.getCatalog()); + this.user = uri.getUser(); - this.user = requireNonNull(user, "user is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); + timeZoneId.set(TimeZone.getDefault().getID()); locale.set(Locale.getDefault()); } @@ -580,10 +582,13 @@ ServerInfo getServerInfo() return serverInfo.get(); } - StatementClient startQuery(String sql) + StatementClient startQuery(String sql, Map sessionPropertiesOverride) { String source = firstNonNull(clientInfo.get("ApplicationName"), "presto-jdbc"); + Map allProperties = new HashMap<>(sessionProperties); + allProperties.putAll(sessionPropertiesOverride); + ClientSession session = new ClientSession( httpUri, user, @@ -593,7 +598,7 @@ StatementClient startQuery(String sql) schema.get(), timeZoneId.get(), locale.get(), - ImmutableMap.copyOf(sessionProperties), + ImmutableMap.copyOf(allProperties), transactionId.get(), false, new Duration(2, MINUTES)); diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java index bfbfe35ac628c..9ee216947c5af 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriver.java @@ -14,6 +14,7 @@ package com.facebook.presto.jdbc; import com.google.common.base.Throwables; +import okhttp3.OkHttpClient; import java.io.Closeable; import java.sql.Connection; @@ -27,10 +28,9 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import static com.google.common.base.Strings.isNullOrEmpty; +import static com.facebook.presto.client.OkHttpUtil.userAgent; import static com.google.common.base.Strings.nullToEmpty; import static java.lang.Integer.parseInt; -import static java.lang.String.format; public class PrestoDriver implements Driver, Closeable @@ -40,13 +40,11 @@ public class PrestoDriver static final int DRIVER_VERSION_MAJOR; static final int DRIVER_VERSION_MINOR; - private static final DriverPropertyInfo[] DRIVER_PROPERTY_INFOS = {}; - private static final String DRIVER_URL_START = "jdbc:presto:"; - private static final String USER_PROPERTY = "user"; - - private final QueryExecutor queryExecutor; + private final OkHttpClient httpClient = new OkHttpClient().newBuilder() + .addInterceptor(userAgent(DRIVER_NAME + "/" + DRIVER_VERSION)) + .build(); static { String version = nullToEmpty(PrestoDriver.class.getPackage().getImplementationVersion()); @@ -70,15 +68,11 @@ public class PrestoDriver } } - public PrestoDriver() - { - this.queryExecutor = QueryExecutor.create(DRIVER_NAME + "/" + DRIVER_VERSION); - } - @Override public void close() { - queryExecutor.close(); + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); } @Override @@ -89,12 +83,13 @@ public Connection connect(String url, Properties info) return null; } - String user = info.getProperty(USER_PROPERTY); - if (isNullOrEmpty(user)) { - throw new SQLException(format("Username property (%s) must be set", USER_PROPERTY)); - } + PrestoDriverUri uri = new PrestoDriverUri(url, info); - return new PrestoConnection(new PrestoDriverUri(url), user, queryExecutor); + OkHttpClient.Builder builder = httpClient.newBuilder(); + uri.setupClient(builder); + QueryExecutor executor = new QueryExecutor(builder.build()); + + return new PrestoConnection(uri, executor); } @Override @@ -108,7 +103,11 @@ public boolean acceptsURL(String url) public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) throws SQLException { - return DRIVER_PROPERTY_INFOS; + Properties properties = new PrestoDriverUri(url, info).getProperties(); + + return ConnectionProperties.allProperties().stream() + .map(property -> property.getDriverPropertyInfo(properties)) + .toArray(DriverPropertyInfo[]::new); } @Override diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java index 35908043eff7d..d366d41e4aa89 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java @@ -13,18 +13,44 @@ */ package com.facebook.presto.jdbc; +import com.facebook.presto.client.ClientException; import com.google.common.base.Splitter; +import com.google.common.collect.Maps; import com.google.common.net.HostAndPort; +import okhttp3.OkHttpClient; +import java.io.File; import java.net.URI; import java.net.URISyntaxException; import java.sql.SQLException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Properties; +import static com.facebook.presto.client.KerberosUtil.defaultCredentialCachePath; +import static com.facebook.presto.client.OkHttpUtil.basicAuth; +import static com.facebook.presto.client.OkHttpUtil.setupHttpProxy; +import static com.facebook.presto.client.OkHttpUtil.setupKerberos; +import static com.facebook.presto.client.OkHttpUtil.setupSocksProxy; +import static com.facebook.presto.client.OkHttpUtil.setupSsl; +import static com.facebook.presto.jdbc.ConnectionProperties.HTTP_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_CONFIG_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_CREDENTIAL_CACHE_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_KEYTAB_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_PRINCIPAL; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_REMOTE_SERICE_NAME; +import static com.facebook.presto.jdbc.ConnectionProperties.KERBEROS_USE_CANONICAL_HOSTNAME; +import static com.facebook.presto.jdbc.ConnectionProperties.PASSWORD; +import static com.facebook.presto.jdbc.ConnectionProperties.SOCKS_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PASSWORD; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PATH; +import static com.facebook.presto.jdbc.ConnectionProperties.USER; import static com.google.common.base.Strings.isNullOrEmpty; -import static io.airlift.http.client.HttpUriBuilder.uriBuilder; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; /** @@ -40,25 +66,29 @@ final class PrestoDriverUri private final HostAndPort address; private final URI uri; + private final Properties properties; + private String catalog; private String schema; private final boolean useSecureConnection; - public PrestoDriverUri(String url) + public PrestoDriverUri(String url, Properties driverProperties) throws SQLException { - this(parseDriverUrl(url)); + this(parseDriverUrl(url), driverProperties); } - private PrestoDriverUri(URI uri) + private PrestoDriverUri(URI uri, Properties driverProperties) throws SQLException { this.uri = requireNonNull(uri, "uri is null"); - this.address = HostAndPort.fromParts(uri.getHost(), uri.getPort()); + address = HostAndPort.fromParts(uri.getHost(), uri.getPort()); + properties = mergeConnectionProperties(uri, driverProperties); + + validateConnectionProperties(properties); - Map params = parseParameters(uri.getQuery()); - useSecureConnection = Boolean.parseBoolean(params.get("secure")); + useSecureConnection = SSL.getRequiredValue(properties); initCatalogAndSchema(); } @@ -83,7 +113,64 @@ public URI getHttpUri() return buildHttpUri(); } + public String getUser() + throws SQLException + { + return USER.getRequiredValue(properties); + } + + public Properties getProperties() + { + return properties; + } + + public void setupClient(OkHttpClient.Builder builder) + throws SQLException + { + try { + setupSocksProxy(builder, SOCKS_PROXY.getValue(properties)); + setupHttpProxy(builder, HTTP_PROXY.getValue(properties)); + + // TODO: fix Tempto to allow empty passwords + String password = PASSWORD.getValue(properties).orElse(""); + if (!password.isEmpty() && !password.equals("***empty***")) { + if (!useSecureConnection) { + throw new SQLException("Authentication using username/password requires SSL to be enabled"); + } + builder.addInterceptor(basicAuth(getUser(), password)); + } + + if (useSecureConnection) { + Optional trustStorePath = SSL_TRUST_STORE_PATH.getValue(properties); + Optional trustStorePassword = SSL_TRUST_STORE_PASSWORD.getValue(properties); + setupSsl(builder, Optional.empty(), Optional.empty(), trustStorePath, trustStorePassword); + } + + if (KERBEROS_REMOTE_SERICE_NAME.getValue(properties).isPresent()) { + if (!useSecureConnection) { + throw new SQLException("Authentication using Kerberos requires SSL to be enabled"); + } + setupKerberos( + builder, + KERBEROS_REMOTE_SERICE_NAME.getRequiredValue(properties), + KERBEROS_USE_CANONICAL_HOSTNAME.getRequiredValue(properties), + KERBEROS_PRINCIPAL.getValue(properties), + KERBEROS_CONFIG_PATH.getValue(properties), + KERBEROS_KEYTAB_PATH.getValue(properties), + Optional.ofNullable(KERBEROS_CREDENTIAL_CACHE_PATH.getValue(properties) + .orElseGet(() -> defaultCredentialCachePath().map(File::new).orElse(null)))); + } + } + catch (ClientException e) { + throw new SQLException(e.getMessage(), e); + } + catch (RuntimeException e) { + throw new SQLException("Error setting up connection", e); + } + } + private static Map parseParameters(String query) + throws SQLException { Map result = new HashMap<>(); @@ -91,7 +178,9 @@ private static Map parseParameters(String query) Iterable queryArgs = QUERY_SPLITTER.split(query); for (String queryArg : queryArgs) { List parts = ARG_SPLITTER.splitToList(queryArg); - result.put(parts.get(0), parts.get(1)); + if (result.put(parts.get(0), parts.get(1)) != null) { + throw new SQLException(format("Connection property '%s' is in URL multiple times", parts.get(0))); + } } } @@ -122,12 +211,13 @@ private static URI parseDriverUrl(String url) private URI buildHttpUri() { - String scheme = (address.getPort() == 443 || useSecureConnection) ? "https" : "http"; - - return uriBuilder() - .scheme(scheme) - .host(address.getHostText()).port(address.getPort()) - .build(); + String scheme = useSecureConnection ? "https" : "http"; + try { + return new URI(scheme, null, address.getHost(), address.getPort(), null, null, null); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } } private void initCatalogAndSchema() @@ -166,4 +256,45 @@ private void initCatalogAndSchema() schema = parts.get(1); } } + + private static Properties mergeConnectionProperties(URI uri, Properties driverProperties) + throws SQLException + { + Map defaults = ConnectionProperties.getDefaults(); + Map urlProperties = parseParameters(uri.getQuery()); + Map suppliedProperties = Maps.fromProperties(driverProperties); + + for (String key : urlProperties.keySet()) { + if (suppliedProperties.containsKey(key)) { + throw new SQLException(format("Connection property '%s' is both in the URL and an argument", key)); + } + } + + Properties result = new Properties(); + setProperties(result, defaults); + setProperties(result, urlProperties); + setProperties(result, suppliedProperties); + return result; + } + + private static void setProperties(Properties properties, Map values) + { + for (Entry entry : values.entrySet()) { + properties.setProperty(entry.getKey(), entry.getValue()); + } + } + + private static void validateConnectionProperties(Properties connectionProperties) + throws SQLException + { + for (String propertyName : connectionProperties.stringPropertyNames()) { + if (ConnectionProperties.forKey(propertyName) == null) { + throw new SQLException(format("Unrecognized connection property '%s'", propertyName)); + } + } + + for (ConnectionProperty property : ConnectionProperties.allProperties()) { + property.validate(connectionProperties); + } + } } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java index a13893bcedb7d..a4293dd9b8562 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoStatement.java @@ -13,7 +13,9 @@ */ package com.facebook.presto.jdbc; +import com.facebook.presto.client.ClientException; import com.facebook.presto.client.StatementClient; +import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import java.sql.Connection; @@ -22,6 +24,7 @@ import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLWarning; import java.sql.Statement; +import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -194,6 +197,15 @@ public void setCursorName(String name) // ignore: positioned modifications not supported } + private Map getStatementSessionProperties() + { + ImmutableMap.Builder sessionProperties = ImmutableMap.builder(); + if (queryTimeoutSeconds.get() > 0) { + sessionProperties.put("query_max_run_time", queryTimeoutSeconds.get() + "s"); + } + return sessionProperties.build(); + } + @Override public boolean execute(String sql) throws SQLException @@ -204,7 +216,7 @@ public boolean execute(String sql) StatementClient client = null; ResultSet resultSet = null; try { - client = connection().startQuery(sql); + client = connection().startQuery(sql, getStatementSessionProperties()); if (client.isFailed()) { throw resultsException(client.finalResults()); } @@ -228,6 +240,9 @@ public boolean execute(String sql) return false; } + catch (ClientException e) { + throw new SQLException(e.getMessage(), e); + } catch (RuntimeException e) { throw new SQLException("Error executing query", e); } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java index e1964b1bde684..e9a587df38247 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/QueryExecutor.java @@ -13,103 +13,52 @@ */ package com.facebook.presto.jdbc; +import com.facebook.presto.client.ClientException; import com.facebook.presto.client.ClientSession; -import com.facebook.presto.client.QueryResults; +import com.facebook.presto.client.JsonResponse; import com.facebook.presto.client.ServerInfo; import com.facebook.presto.client.StatementClient; -import com.google.common.collect.ImmutableSet; -import com.google.common.net.HostAndPort; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClientConfig; -import io.airlift.http.client.Request; -import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.http.client.jetty.JettyIoPool; -import io.airlift.http.client.jetty.JettyIoPoolConfig; import io.airlift.json.JsonCodec; -import io.airlift.units.Duration; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; -import javax.annotation.Nullable; - -import java.io.Closeable; -import java.net.InetSocketAddress; -import java.net.Proxy; -import java.net.ProxySelector; import java.net.URI; -import java.util.concurrent.TimeUnit; -import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; -import static io.airlift.http.client.Request.Builder.prepareGet; import static io.airlift.json.JsonCodec.jsonCodec; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; class QueryExecutor - implements Closeable { - private final JsonCodec queryInfoCodec; - private final JsonCodec serverInfoCodec; - private final HttpClient httpClient; + private static final JsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); + + private final OkHttpClient httpClient; - private QueryExecutor(JsonCodec queryResultsCodec, JsonCodec serverInfoCodec, HttpClient httpClient) + public QueryExecutor(OkHttpClient httpClient) { - this.queryInfoCodec = requireNonNull(queryResultsCodec, "queryResultsCodec is null"); - this.serverInfoCodec = requireNonNull(serverInfoCodec, "serverInfoCodec is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); } public StatementClient startQuery(ClientSession session, String query) { - return new StatementClient(httpClient, queryInfoCodec, session, query); - } - - @Override - public void close() - { - httpClient.close(); + return new StatementClient(httpClient, session, query); } public ServerInfo getServerInfo(URI server) { - URI uri = uriBuilderFrom(server).replacePath("/v1/info").build(); - Request request = prepareGet().setUri(uri).build(); - return httpClient.execute(request, createJsonResponseHandler(serverInfoCodec)); - } - - // TODO: replace this with a phantom reference - @SuppressWarnings("FinalizeDeclaration") - @Override - protected void finalize() - { - close(); - } - - static QueryExecutor create(String userAgent) - { - return create(new JettyHttpClient( - new HttpClientConfig() - .setConnectTimeout(new Duration(10, TimeUnit.SECONDS)) - .setSocksProxy(getSystemSocksProxy()), - new JettyIoPool("presto-jdbc", new JettyIoPoolConfig()), - ImmutableSet.of(new UserAgentRequestFilter(userAgent)))); - } + HttpUrl url = HttpUrl.get(server); + if (url == null) { + throw new ClientException("Invalid server URL: " + server); + } + url = url.newBuilder().encodedPath("/v1/info").build(); - static QueryExecutor create(HttpClient httpClient) - { - return new QueryExecutor(jsonCodec(QueryResults.class), jsonCodec(ServerInfo.class), httpClient); - } + Request request = new Request.Builder().url(url).build(); - @Nullable - private static HostAndPort getSystemSocksProxy() - { - URI uri = URI.create("socket://0.0.0.0:80"); - for (Proxy proxy : ProxySelector.getDefault().select(uri)) { - if (proxy.type() == Proxy.Type.SOCKS) { - if (proxy.address() instanceof InetSocketAddress) { - InetSocketAddress address = (InetSocketAddress) proxy.address(); - return HostAndPort.fromParts(address.getHostString(), address.getPort()); - } - } + JsonResponse response = JsonResponse.execute(SERVER_INFO_CODEC, httpClient, request); + if (!response.hasValue()) { + throw new RuntimeException(format("Request to %s failed: %s [Error: %s]", server, response, response.getResponseBody())); } - return null; + return response.getValue(); } } diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java index aade462a526d2..1e86e6b28df1c 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java @@ -16,6 +16,7 @@ import com.facebook.presto.execution.QueryState; import com.facebook.presto.plugin.blackhole.BlackHolePlugin; import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DateType; @@ -33,7 +34,6 @@ import com.facebook.presto.spi.type.VarbinaryType; import com.facebook.presto.tpch.TpchMetadata; import com.facebook.presto.tpch.TpchPlugin; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.ColorType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -75,6 +75,7 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.testing.Assertions.assertContains; import static io.airlift.testing.Assertions.assertInstanceOf; import static io.airlift.testing.Assertions.assertLessThan; import static io.airlift.units.Duration.nanosSince; @@ -1358,11 +1359,11 @@ public void testBadQuery() } } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Username property \\(user\\) must be set") + @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property 'user' is required") public void testUserIsRequired() throws Exception { - try (Connection ignored = DriverManager.getConnection("jdbc:presto://test.invalid/")) { + try (Connection ignored = DriverManager.getConnection(format("jdbc:presto://%s", server.getAddress()))) { fail("expected exception"); } } @@ -1407,7 +1408,7 @@ public void testQueryCancellation() }); // start query and make sure it is not finished - queryStarted.await(10, SECONDS); + assertTrue(queryStarted.await(10, SECONDS)); assertNotNull(queryId.get()); assertFalse(getQueryState(queryId.get()).isDone()); @@ -1415,7 +1416,7 @@ public void testQueryCancellation() queryFuture.cancel(true); // make sure the query was aborted - queryFinished.await(10, SECONDS); + assertTrue(queryFinished.await(10, SECONDS)); assertNotNull(queryFailure.get()); assertEquals(getQueryState(queryId.get()), FAILED); @@ -1425,6 +1426,54 @@ public void testQueryCancellation() } } + @Test(timeOut = 4000) + public void testQueryTimeout() + throws Exception + { + try (Connection connection = createConnection("blackhole", "blackhole"); + Statement statement = connection.createStatement()) { + statement.executeUpdate("CREATE TABLE test_query_timeout (key BIGINT) " + + "WITH (" + + " split_count = 1, " + + " pages_per_split = 1, " + + " rows_per_page = 1, " + + " page_processing_delay = '1m'" + + ")"); + } + + CountDownLatch queryFinished = new CountDownLatch(1); + AtomicReference queryFailure = new AtomicReference<>(); + + executorService.submit(() -> { + try (Connection connection = createConnection("blackhole", "default"); + Statement statement = connection.createStatement()) { + statement.setQueryTimeout(1); + try (ResultSet resultSet = statement.executeQuery("SELECT * FROM test_query_timeout")) { + try { + resultSet.next(); + } + catch (SQLException t) { + queryFailure.set(t); + } + finally { + queryFinished.countDown(); + } + } + } + return null; + }); + + // make sure the query timed out + assertTrue(queryFinished.await(2, SECONDS)); + assertNotNull(queryFailure.get()); + assertContains(queryFailure.get().getMessage(), "Query exceeded maximum time limit of 1.00s"); + + try (Connection connection = createConnection("blackhole", "blackhole"); + Statement statement = connection.createStatement()) { + statement.executeUpdate("DROP TABLE test_query_timeout"); + } + } + private QueryState getQueryState(String queryId) throws SQLException { diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java index bb84f9132dc65..20f2d62fb3009 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriverUri.java @@ -17,87 +17,189 @@ import java.net.URI; import java.sql.SQLException; +import java.util.Properties; +import static com.facebook.presto.jdbc.ConnectionProperties.HTTP_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.SOCKS_PROXY; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PASSWORD; +import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PATH; import static java.lang.String.format; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.fail; public class TestPrestoDriverUri { - private static final String SERVER = "127.0.0.1:60429"; + @Test + public void testInvalidUrls() + { + // missing port + assertInvalid("jdbc:presto://localhost/", "No port number specified:"); + + // extra path segments + assertInvalid("jdbc:presto://localhost:8080/hive/default/abc", "Invalid path segments in URL:"); + + // extra slash + assertInvalid("jdbc:presto://localhost:8080//", "Catalog name is empty:"); + + // has schema but is missing catalog + assertInvalid("jdbc:presto://localhost:8080//default", "Catalog name is empty:"); + + // has catalog but schema is missing + assertInvalid("jdbc:presto://localhost:8080/a//", "Schema name is empty:"); + + // unrecognized property + assertInvalid("jdbc:presto://localhost:8080/hive/default?ShoeSize=13", "Unrecognized connection property 'ShoeSize'"); + + // empty property + assertInvalid("jdbc:presto://localhost:8080/hive/default?password=", "Connection property 'password' value is empty"); + + // property in url multiple times + assertInvalid("presto://localhost:8080/blackhole?password=a&password=b", "Connection property 'password' is in URL multiple times"); + + // property in both url and arguments + assertInvalid("presto://localhost:8080/blackhole?user=test123", "Connection property 'user' is both in the URL and an argument"); + + // setting both socks and http proxy + assertInvalid("presto://localhost:8080?socksProxy=localhost:1080&httpProxy=localhost:8888", "Connection property 'socksProxy' is not allowed"); + assertInvalid("presto://localhost:8080?httpProxy=localhost:8888&socksProxy=localhost:1080", "Connection property 'socksProxy' is not allowed"); + + // invalid ssl flag + assertInvalid("jdbc:presto://localhost:8080?SSL=0", "Connection property 'SSL' value is invalid: 0"); + assertInvalid("jdbc:presto://localhost:8080?SSL=1", "Connection property 'SSL' value is invalid: 1"); + assertInvalid("jdbc:presto://localhost:8080?SSL=2", "Connection property 'SSL' value is invalid: 2"); + assertInvalid("jdbc:presto://localhost:8080?SSL=abc", "Connection property 'SSL' value is invalid: abc"); - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Invalid path segments in URL: .*") - public void testBadUrlExtraPathSegments() + // ssl trust store password without path + assertInvalid("jdbc:presto://localhost:8080?SSL=true&SSLTrustStorePassword=password", "Connection property 'SSLTrustStorePassword' is not allowed"); + + // trust store path without ssl + assertInvalid("jdbc:presto://localhost:8080?SSLTrustStorePath=truststore.jks", "Connection property 'SSLTrustStorePath' is not allowed"); + + // trust store password without ssl + assertInvalid("jdbc:presto://localhost:8080?SSLTrustStorePassword=password", "Connection property 'SSLTrustStorePassword' is not allowed"); + + // kerberos config without service name + assertInvalid("jdbc:presto://localhost:8080?KerberosCredentialCachePath=/test", "Connection property 'KerberosCredentialCachePath' is not allowed"); + } + + @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Connection property 'user' is required") + public void testRequireUser() throws Exception { - String url = format("jdbc:presto://%s/hive/default/bad_string", SERVER); - new PrestoDriverUri(url); + new PrestoDriverUri("jdbc:presto://localhost:8080", new Properties()); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Catalog name is empty: .*") - public void testBadUrlMissingCatalog() - throws Exception + @Test + void testUriWithSocksProxy() + throws SQLException { - String url = format("jdbc:presto://%s//default", SERVER); - new PrestoDriverUri(url); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080?socksProxy=localhost:1234"); + assertUriPortScheme(parameters, 8080, "http"); + + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(SOCKS_PROXY.getKey()), "localhost:1234"); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Catalog name is empty: .*") - public void testBadUrlEndsInSlashes() - throws Exception + @Test + void testUriWithHttpProxy() + throws SQLException { - String url = format("jdbc:presto://%s//", SERVER); - new PrestoDriverUri(url); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080?httpProxy=localhost:5678"); + assertUriPortScheme(parameters, 8080, "http"); + + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(HTTP_PROXY.getKey()), "localhost:5678"); } - @Test(expectedExceptions = SQLException.class, expectedExceptionsMessageRegExp = "Schema name is empty: .*") - public void testBadUrlMissingSchema() - throws Exception + @Test + public void testUriWithoutSsl() + throws SQLException { - String url = format("jdbc:presto://%s/a//", SERVER); - new PrestoDriverUri(url); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole"); + assertUriPortScheme(parameters, 8080, "http"); } @Test - public void testUrlWithSsl() + public void testUriWithSslPortDoesNotUseSsl() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://some-ssl-server:443/blackhole"); + PrestoDriverUri parameters = createDriverUri("presto://somelocalhost:443/blackhole"); + assertUriPortScheme(parameters, 443, "http"); + } - URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 443); - assertEquals(uri.getScheme(), "https"); + @Test + public void testUriWithSslDisabled() + throws SQLException + { + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=false"); + assertUriPortScheme(parameters, 8080, "http"); } @Test - public void testUriWithSecureMissing() + public void testUriWithSslEnabled() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://localhost:8080/blackhole"); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=true"); + assertUriPortScheme(parameters, 8080, "https"); - URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 8080); - assertEquals(uri.getScheme(), "http"); + Properties properties = parameters.getProperties(); + assertNull(properties.getProperty(SSL_TRUST_STORE_PATH.getKey())); + assertNull(properties.getProperty(SSL_TRUST_STORE_PASSWORD.getKey())); } @Test - public void testUriWithSecureTrue() + public void testUriWithSslEnabledPathOnly() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://localhost:8080/blackhole?secure=true"); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=true&SSLTrustStorePath=truststore.jks"); + assertUriPortScheme(parameters, 8080, "https"); - URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 8080); - assertEquals(uri.getScheme(), "https"); + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(SSL_TRUST_STORE_PATH.getKey()), "truststore.jks"); + assertNull(properties.getProperty(SSL_TRUST_STORE_PASSWORD.getKey())); } @Test - public void testUriWithSecureFalse() + public void testUriWithSslEnabledPassword() throws SQLException { - PrestoDriverUri parameters = new PrestoDriverUri("presto://localhost:8080/blackhole?secure=false"); + PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?SSL=true&SSLTrustStorePath=truststore.jks&SSLTrustStorePassword=password"); + assertUriPortScheme(parameters, 8080, "https"); + Properties properties = parameters.getProperties(); + assertEquals(properties.getProperty(SSL_TRUST_STORE_PATH.getKey()), "truststore.jks"); + assertEquals(properties.getProperty(SSL_TRUST_STORE_PASSWORD.getKey()), "password"); + } + + private static void assertUriPortScheme(PrestoDriverUri parameters, int port, String scheme) + { URI uri = parameters.getHttpUri(); - assertEquals(uri.getPort(), 8080); - assertEquals(uri.getScheme(), "http"); + assertEquals(uri.getPort(), port); + assertEquals(uri.getScheme(), scheme); + } + + private static PrestoDriverUri createDriverUri(String url) + throws SQLException + { + Properties properties = new Properties(); + properties.setProperty("user", "test"); + + return new PrestoDriverUri(url, properties); + } + + private static void assertInvalid(String url, String prefix) + { + try { + createDriverUri(url); + fail("expected exception"); + } + catch (SQLException e) { + assertNotNull(e.getMessage()); + if (!e.getMessage().startsWith(prefix)) { + fail(format("expected:<%s> to start with <%s>", e.getMessage(), prefix)); + } + } } } diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java index 66c21737a46fa..f937aafd1eef4 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestProgressMonitor.java @@ -18,32 +18,28 @@ import com.facebook.presto.client.QueryResults; import com.facebook.presto.client.StatementStats; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.net.HttpHeaders; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpStatus; -import io.airlift.http.client.Request; -import io.airlift.http.client.Response; -import io.airlift.http.client.testing.TestingHttpClient; -import io.airlift.http.client.testing.TestingResponse; import io.airlift.json.JsonCodec; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import java.net.URI; +import java.io.IOException; import java.sql.Connection; +import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; -import java.util.Iterator; import java.util.List; import java.util.function.Consumer; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static java.lang.String.format; -import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -51,28 +47,46 @@ @Test(singleThreaded = true) public class TestProgressMonitor { - private static final String SERVER_ADDRESS = "127.0.0.1:8080"; private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); - private static final String QUERY_ID = "20160128_214710_00012_rk68b"; - private static final String INFO_URI = "http://" + SERVER_ADDRESS + "/query.html?" + QUERY_ID; - private static final String PARTIAL_CANCEL_URI = "http://" + SERVER_ADDRESS + "/v1/stage/" + QUERY_ID + ".%d"; - private static final String NEXT_URI = "http://" + SERVER_ADDRESS + "/v1/statement/" + QUERY_ID + "/%d"; - private static final List RESPONSE_COLUMNS = ImmutableList.of(new Column("_col0", "bigint", new ClientTypeSignature("bigint", ImmutableList.of()))); - private static final List RESPONSES = ImmutableList.of( - newQueryResults(null, 1, null, null, "QUEUED"), - newQueryResults(1, 2, RESPONSE_COLUMNS, null, "RUNNING"), - newQueryResults(1, 3, RESPONSE_COLUMNS, null, "RUNNING"), - newQueryResults(0, 4, RESPONSE_COLUMNS, ImmutableList.of(ImmutableList.of(253161)), "RUNNING"), - newQueryResults(null, null, RESPONSE_COLUMNS, null, "FINISHED")); - - private static String newQueryResults(Integer partialCancelId, Integer nextUriId, List responseColumns, List> data, String state) + private MockWebServer server; + + @BeforeMethod + public void setup() + throws IOException + { + server = new MockWebServer(); + server.start(); + } + + @AfterMethod + public void teardown() + throws IOException + { + server.close(); + } + + private List createResults() + { + List columns = ImmutableList.of(new Column("_col0", "bigint", new ClientTypeSignature("bigint", ImmutableList.of()))); + return ImmutableList.builder() + .add(newQueryResults(null, 1, null, null, "QUEUED")) + .add(newQueryResults(1, 2, columns, null, "RUNNING")) + .add(newQueryResults(1, 3, columns, null, "RUNNING")) + .add(newQueryResults(0, 4, columns, ImmutableList.of(ImmutableList.of(253161)), "RUNNING")) + .add(newQueryResults(null, null, columns, null, "FINISHED")) + .build(); + } + + private String newQueryResults(Integer partialCancelId, Integer nextUriId, List responseColumns, List> data, String state) { + String queryId = "20160128_214710_00012_rk68b"; + QueryResults queryResults = new QueryResults( - QUERY_ID, - URI.create(INFO_URI), - partialCancelId == null ? null : URI.create(format(PARTIAL_CANCEL_URI, partialCancelId)), - nextUriId == null ? null : URI.create(format(NEXT_URI, nextUriId)), + queryId, + server.url("/query.html?" + queryId).uri(), + partialCancelId == null ? null : server.url(format("/v1/stage/%s.%s", queryId, partialCancelId)).uri(), + nextUriId == null ? null : server.url(format("/v1/statement/%s/%s", queryId, nextUriId)).uri(), responseColumns, data, new StatementStats(state, state.equals("QUEUED"), true, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null), @@ -87,6 +101,12 @@ private static String newQueryResults(Integer partialCancelId, Integer nextUriId public void test() throws SQLException { + for (String result : createResults()) { + server.enqueue(new MockResponse() + .addHeader(CONTENT_TYPE, "application/json") + .setBody(result)); + } + try (Connection connection = createConnection()) { try (Statement statement = connection.createStatement()) { PrestoStatement prestoStatement = statement.unwrap(PrestoStatement.class); @@ -116,40 +136,15 @@ public void test() private Connection createConnection() throws SQLException { - HttpClient client = new TestingHttpClient(new TestingHttpClientProcessor(RESPONSES)); - QueryExecutor testQueryExecutor = QueryExecutor.create(client); - String uri = format("prestotest://%s", SERVER_ADDRESS); - return new PrestoConnection(new PrestoDriverUri(uri), "test", testQueryExecutor); - } - - private static class TestingHttpClientProcessor - implements TestingHttpClient.Processor - { - private final Iterator responses; - - public TestingHttpClientProcessor(List responses) - { - this.responses = ImmutableList.copyOf(requireNonNull(responses, "responses is null")).iterator(); - } - - @Override - public synchronized Response handle(Request request) - throws Exception - { - checkState(responses.hasNext(), "too many requests (ran out of test responses)"); - Response response = new TestingResponse( - HttpStatus.OK, - ImmutableListMultimap.of(HttpHeaders.CONTENT_TYPE, "application/json"), - responses.next().getBytes()); - return response; - } + String url = format("jdbc:presto://%s", server.url("/").uri().getAuthority()); + return DriverManager.getConnection(url, "test", null); } private static class RecordingProgressMonitor implements Consumer { private final ImmutableList.Builder builder = ImmutableList.builder(); - private boolean finished = false; + private boolean finished; @Override public synchronized void accept(QueryStats queryStats) diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java new file mode 100644 index 0000000000000..6693748e55ad3 --- /dev/null +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import com.facebook.presto.client.ServerInfo; +import io.airlift.json.JsonCodec; +import io.airlift.units.Duration; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.util.Optional; + +import static com.facebook.presto.client.NodeVersion.UNKNOWN; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static io.airlift.json.JsonCodec.jsonCodec; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestQueryExecutor +{ + private static final JsonCodec SERVER_INFO_CODEC = jsonCodec(ServerInfo.class); + + private MockWebServer server; + + @BeforeMethod + public void setup() + throws IOException + { + server = new MockWebServer(); + server.start(); + } + + @AfterMethod + public void teardown() + throws IOException + { + server.close(); + } + + @Test + public void testGetServerInfo() + throws Exception + { + ServerInfo expected = new ServerInfo(UNKNOWN, "test", true, Optional.of(Duration.valueOf("2m"))); + + server.enqueue(new MockResponse() + .addHeader(CONTENT_TYPE, "application/json") + .setBody(SERVER_INFO_CODEC.toJson(expected))); + + QueryExecutor executor = new QueryExecutor(new OkHttpClient()); + + ServerInfo actual = executor.getServerInfo(server.url("/v1/info").uri()); + assertEquals(actual.getEnvironment(), "test"); + assertEquals(actual.getUptime(), Optional.of(Duration.valueOf("2m"))); + + assertEquals(server.getRequestCount(), 1); + assertEquals(server.takeRequest().getPath(), "/v1/info"); + } +} diff --git a/presto-jmx/pom.xml b/presto-jmx/pom.xml index 38b456f736ac7..1ae370bd3ea5e 100644 --- a/presto-jmx/pom.xml +++ b/presto-jmx/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-jmx diff --git a/presto-kafka/pom.xml b/presto-kafka/pom.xml index 669da67c88c2d..ea2d6db6bcf09 100644 --- a/presto-kafka/pom.xml +++ b/presto-kafka/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-kafka diff --git a/presto-local-file/pom.xml b/presto-local-file/pom.xml index 562a1ac33d945..1e53188fe1c02 100644 --- a/presto-local-file/pom.xml +++ b/presto-local-file/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-local-file diff --git a/presto-main/etc/catalog/mysql.properties b/presto-main/etc/catalog/mysql.properties index db7463f53f77b..bb6b444c91d09 100644 --- a/presto-main/etc/catalog/mysql.properties +++ b/presto-main/etc/catalog/mysql.properties @@ -1,4 +1,4 @@ connector.name=mysql -connection-url=jdbc:mysql://mysql:13306/test +connection-url=jdbc:mysql://mysql:13306 connection-user=root connection-password=swarm diff --git a/presto-main/pom.xml b/presto-main/pom.xml index df470734f9fcd..dd535dce447c9 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-main @@ -324,6 +324,12 @@ tpch test + + + io.airlift + jaxrs-testing + test + diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 63433ee545275..c693ec2b28e63 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -61,7 +61,7 @@ public final class SystemSessionProperties public static final String INITIAL_SPLITS_PER_NODE = "initial_splits_per_node"; public static final String SPLIT_CONCURRENCY_ADJUSTMENT_INTERVAL = "split_concurrency_adjustment_interval"; public static final String OPTIMIZE_METADATA_QUERIES = "optimize_metadata_queries"; - public static final String FAST_INEQUALITY_JOIN = "fast_inequality_join"; + public static final String FAST_INEQUALITY_JOINS = "fast_inequality_joins"; public static final String QUERY_PRIORITY = "query_priority"; public static final String SPILL_ENABLED = "spill_enabled"; public static final String OPERATOR_MEMORY_LIMIT_BEFORE_SPILL = "operator_memory_limit_before_spill"; @@ -71,6 +71,8 @@ public final class SystemSessionProperties public static final String ITERATIVE_OPTIMIZER_TIMEOUT = "iterative_optimizer_timeout"; public static final String EXCHANGE_COMPRESSION = "exchange_compression"; public static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations"; + public static final String PUSH_AGGREGATION_THROUGH_JOIN = "push_aggregation_through_join"; + public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join"; private final List> sessionProperties; @@ -238,8 +240,8 @@ public SystemSessionProperties( featuresConfig.isJoinReorderingEnabled(), false), booleanSessionProperty( - FAST_INEQUALITY_JOIN, - "Experimental: Use faster handling of inequality join if it is possible", + FAST_INEQUALITY_JOINS, + "Use faster handling of inequality join if it is possible", featuresConfig.isFastInequalityJoins(), false), booleanSessionProperty( @@ -306,6 +308,16 @@ public SystemSessionProperties( ENABLE_INTERMEDIATE_AGGREGATIONS, "Enable the use of intermediate aggregations", featuresConfig.isEnableIntermediateAggregations(), + false), + booleanSessionProperty( + PUSH_AGGREGATION_THROUGH_JOIN, + "Allow pushing aggregations below joins", + featuresConfig.isPushAggregationThroughJoin(), + false), + booleanSessionProperty( + PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, + "Push partial aggregations below joins", + false, false)); } @@ -401,7 +413,7 @@ public static boolean planWithTableNodePartitioning(Session session) public static boolean isFastInequalityJoin(Session session) { - return session.getSystemProperty(FAST_INEQUALITY_JOIN, Boolean.class); + return session.getSystemProperty(FAST_INEQUALITY_JOINS, Boolean.class); } public static boolean isJoinReorderingEnabled(Session session) @@ -477,4 +489,14 @@ public static boolean isEnableIntermediateAggregations(Session session) { return session.getSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, Boolean.class); } + + public static boolean shouldPushAggregationThroughJoin(Session session) + { + return session.getSystemProperty(PUSH_AGGREGATION_THROUGH_JOIN, Boolean.class); + } + + public static boolean isPushAggregationThroughJoin(Session session) + { + return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/block/BlockEncodingManager.java b/presto-main/src/main/java/com/facebook/presto/block/BlockEncodingManager.java index 01bb5e1533bc5..f841ab69f8cff 100644 --- a/presto-main/src/main/java/com/facebook/presto/block/BlockEncodingManager.java +++ b/presto-main/src/main/java/com/facebook/presto/block/BlockEncodingManager.java @@ -23,8 +23,10 @@ import com.facebook.presto.spi.block.IntArrayBlockEncoding; import com.facebook.presto.spi.block.InterleavedBlockEncoding; import com.facebook.presto.spi.block.LongArrayBlockEncoding; +import com.facebook.presto.spi.block.MapBlockEncoding; import com.facebook.presto.spi.block.RunLengthBlockEncoding; import com.facebook.presto.spi.block.ShortArrayBlockEncoding; +import com.facebook.presto.spi.block.SingleMapBlockEncoding; import com.facebook.presto.spi.block.SliceArrayBlockEncoding; import com.facebook.presto.spi.block.VariableWidthBlockEncoding; import com.facebook.presto.spi.type.TypeManager; @@ -71,6 +73,8 @@ public BlockEncodingManager(TypeManager typeManager, Set addBlockEncodingFactory(DictionaryBlockEncoding.FACTORY); addBlockEncodingFactory(ArrayBlockEncoding.FACTORY); addBlockEncodingFactory(InterleavedBlockEncoding.FACTORY); + addBlockEncodingFactory(MapBlockEncoding.FACTORY); + addBlockEncodingFactory(SingleMapBlockEncoding.FACTORY); addBlockEncodingFactory(RunLengthBlockEncoding.FACTORY); for (BlockEncodingFactory factory : requireNonNull(blockEncodingFactories, "blockEncodingFactories is null")) { diff --git a/presto-main/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java b/presto-main/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java index b764e9cee76ca..aebfc601c084a 100644 --- a/presto-main/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java +++ b/presto-main/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java @@ -22,12 +22,16 @@ import javax.inject.Inject; +import java.lang.invoke.MethodHandle; + import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; import static com.facebook.presto.util.Reflection.methodHandle; import static java.util.Objects.requireNonNull; public class KillQueryProcedure { + private static final MethodHandle KILL_QUERY = methodHandle(KillQueryProcedure.class, "killQuery", String.class); + private final QueryManager queryManager; @Inject @@ -48,6 +52,6 @@ public Procedure getProcedure() "runtime", "kill_query", ImmutableList.of(new Argument("query_id", VARCHAR)), - methodHandle(getClass(), "killQuery", String.class).bindTo(this)); + KILL_QUERY.bindTo(this)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java b/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java index 3786624e651aa..a4363f46b26f4 100644 --- a/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java +++ b/presto-main/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java @@ -26,11 +26,11 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.ArrayType; import javax.inject.Inject; diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedCostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedCostCalculator.java new file mode 100644 index 0000000000000..f2326e379bcda --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CoefficientBasedCostCalculator.java @@ -0,0 +1,279 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.DomainTranslator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.OutputNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.PlanVisitor; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.BooleanLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; + +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static com.facebook.presto.cost.PlanNodeCost.UNKNOWN_COST; + +/** + * Simple implementation of CostCalculator. It make many arbitrary decisions (e.g filtering selectivity, join matching). + * It serves POC purpose. To be replaced with more advanced implementation. + */ +@ThreadSafe +public class CoefficientBasedCostCalculator + implements CostCalculator +{ + private static final Double FILTER_COEFFICIENT = 0.5; + private static final Double JOIN_MATCHING_COEFFICIENT = 2.0; + + // todo some computation for outputSizeInBytes + + private final Metadata metadata; + + @Inject + public CoefficientBasedCostCalculator(Metadata metadata) + { + this.metadata = metadata; + } + + @Override + public Map calculateCostForPlan(Session session, Map types, PlanNode planNode) + { + Visitor visitor = new Visitor(session, types); + planNode.accept(visitor, null); + return ImmutableMap.copyOf(visitor.getCosts()); + } + + private class Visitor + extends PlanVisitor + { + private final Session session; + private final Map costs; + private final Map types; + + public Visitor(Session session, Map types) + { + this.costs = new HashMap<>(); + this.session = session; + this.types = ImmutableMap.copyOf(types); + } + + public Map getCosts() + { + return ImmutableMap.copyOf(costs); + } + + @Override + protected PlanNodeCost visitPlan(PlanNode node, Void context) + { + visitSources(node); + costs.put(node.getId(), UNKNOWN_COST); + return UNKNOWN_COST; + } + + @Override + public PlanNodeCost visitOutput(OutputNode node, Void context) + { + return copySourceCost(node); + } + + @Override + public PlanNodeCost visitFilter(FilterNode node, Void context) + { + PlanNodeCost sourceCost; + if (node.getSource() instanceof TableScanNode) { + sourceCost = visitTableScanWithPredicate((TableScanNode) node.getSource(), node.getPredicate()); + } + else { + sourceCost = visitSource(node); + } + + final double filterCoefficient = FILTER_COEFFICIENT; + PlanNodeCost filterCost = sourceCost + .mapOutputRowCount(value -> value * filterCoefficient); + costs.put(node.getId(), filterCost); + return filterCost; + } + + @Override + public PlanNodeCost visitProject(ProjectNode node, Void context) + { + return copySourceCost(node); + } + + @Override + public PlanNodeCost visitJoin(JoinNode node, Void context) + { + List sourceCosts = visitSources(node); + PlanNodeCost leftCost = sourceCosts.get(0); + PlanNodeCost rightCost = sourceCosts.get(1); + + PlanNodeCost.Builder joinCost = PlanNodeCost.builder(); + if (!leftCost.getOutputRowCount().isValueUnknown() && !rightCost.getOutputRowCount().isValueUnknown()) { + double rowCount = Math.max(leftCost.getOutputRowCount().getValue(), rightCost.getOutputRowCount().getValue()) * JOIN_MATCHING_COEFFICIENT; + joinCost.setOutputRowCount(new Estimate(rowCount)); + } + + costs.put(node.getId(), joinCost.build()); + return joinCost.build(); + } + + @Override + public PlanNodeCost visitExchange(ExchangeNode node, Void context) + { + List sourceCosts = visitSources(node); + Estimate rowCount = new Estimate(0); + for (PlanNodeCost sourceCost : sourceCosts) { + if (sourceCost.getOutputRowCount().isValueUnknown()) { + rowCount = Estimate.unknownValue(); + } + else { + rowCount = rowCount.map(value -> value + sourceCost.getOutputRowCount().getValue()); + } + } + + PlanNodeCost exchangeCost = PlanNodeCost.builder() + .setOutputRowCount(rowCount) + .build(); + costs.put(node.getId(), exchangeCost); + return exchangeCost; + } + + @Override + public PlanNodeCost visitTableScan(TableScanNode node, Void context) + { + return visitTableScanWithPredicate(node, BooleanLiteral.TRUE_LITERAL); + } + + private PlanNodeCost visitTableScanWithPredicate(TableScanNode node, Expression predicate) + { + Constraint constraint = getConstraint(node, predicate); + + TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), constraint); + PlanNodeCost tableScanCost = PlanNodeCost.builder() + .setOutputRowCount(tableStatistics.getRowCount()) + .build(); + + costs.put(node.getId(), tableScanCost); + return tableScanCost; + } + + private Constraint getConstraint(TableScanNode node, Expression predicate) + { + DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( + metadata, + session, + predicate, + types); + + TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain() + .transform(node.getAssignments()::get) + .intersect(node.getCurrentConstraint()); + + return new Constraint<>(simplifiedConstraint, bindings -> true); + } + + @Override + public PlanNodeCost visitValues(ValuesNode node, Void context) + { + Estimate valuesCount = new Estimate(node.getRows().size()); + PlanNodeCost valuesCost = PlanNodeCost.builder() + .setOutputRowCount(valuesCount) + .build(); + costs.put(node.getId(), valuesCost); + return valuesCost; + } + + @Override + public PlanNodeCost visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + { + visitSources(node); + PlanNodeCost nodeCost = PlanNodeCost.builder() + .setOutputRowCount(new Estimate(1.0)) + .build(); + costs.put(node.getId(), nodeCost); + return nodeCost; + } + + @Override + public PlanNodeCost visitSemiJoin(SemiJoinNode node, Void context) + { + visitSources(node); + PlanNodeCost sourceStatitics = costs.get(node.getSource().getId()); + PlanNodeCost semiJoinCost = sourceStatitics.mapOutputRowCount(rowCount -> rowCount * JOIN_MATCHING_COEFFICIENT); + costs.put(node.getId(), semiJoinCost); + return semiJoinCost; + } + + @Override + public PlanNodeCost visitLimit(LimitNode node, Void context) + { + PlanNodeCost sourceCost = visitSource(node); + PlanNodeCost.Builder limitCost = PlanNodeCost.builder(); + if (sourceCost.getOutputRowCount().getValue() < node.getCount()) { + limitCost.setOutputRowCount(sourceCost.getOutputRowCount()); + } + else { + limitCost.setOutputRowCount(new Estimate(node.getCount())); + } + costs.put(node.getId(), limitCost.build()); + return limitCost.build(); + } + + private PlanNodeCost copySourceCost(PlanNode node) + { + PlanNodeCost sourceCost = visitSource(node); + costs.put(node.getId(), sourceCost); + return sourceCost; + } + + private List visitSources(PlanNode node) + { + return node.getSources().stream() + .map(source -> source.accept(this, null)) + .collect(Collectors.toList()); + } + + private PlanNodeCost visitSource(PlanNode node) + { + return Iterables.getOnlyElement(visitSources(node)); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java new file mode 100644 index 0000000000000..f5d1f87743a21 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; + +import java.util.Map; + +/** + * Interface of cost calculator. + * + * It's responsibility is to provide approximation of cost of execution of plan node. + * Example implementations may be based on table statistics or data samples. + */ +public interface CostCalculator +{ + Map calculateCostForPlan(Session session, Map types, PlanNode planNode); + + default PlanNodeCost calculateCostForNode(Session session, Map types, PlanNode planNode) + { + return calculateCostForPlan(session, types, planNode).get(planNode.getId()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java new file mode 100644 index 0000000000000..c30eaa90538f3 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeCost.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.Estimate; + +import java.util.Objects; +import java.util.function.Function; + +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static java.util.Objects.requireNonNull; + +public class PlanNodeCost +{ + public static final PlanNodeCost UNKNOWN_COST = PlanNodeCost.builder().build(); + + private final Estimate outputRowCount; + private final Estimate outputSizeInBytes; + + private PlanNodeCost(Estimate outputRowCount, Estimate outputSizeInBytes) + { + this.outputRowCount = requireNonNull(outputRowCount, "outputRowCount can not be null"); + this.outputSizeInBytes = requireNonNull(outputSizeInBytes, "outputSizeInBytes can not be null"); + } + + public Estimate getOutputRowCount() + { + return outputRowCount; + } + + public Estimate getOutputSizeInBytes() + { + return outputSizeInBytes; + } + + public PlanNodeCost mapOutputRowCount(Function mappingFunction) + { + return buildFrom(this).setOutputRowCount(outputRowCount.map(mappingFunction)).build(); + } + + public PlanNodeCost mapOutputSizeInBytes(Function mappingFunction) + { + return buildFrom(this).setOutputSizeInBytes(outputRowCount.map(mappingFunction)).build(); + } + + @Override + public String toString() + { + return "PlanNodeCost{outputRowCount=" + outputRowCount + ", outputSizeInBytes=" + outputSizeInBytes + '}'; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PlanNodeCost that = (PlanNodeCost) o; + return Objects.equals(outputRowCount, that.outputRowCount) && + Objects.equals(outputSizeInBytes, that.outputSizeInBytes); + } + + @Override + public int hashCode() + { + return Objects.hash(outputRowCount, outputSizeInBytes); + } + + public static Builder builder() + { + return new Builder(); + } + + public static Builder buildFrom(PlanNodeCost other) + { + return builder().setOutputRowCount(other.getOutputRowCount()) + .setOutputSizeInBytes(other.getOutputSizeInBytes()); + } + + public static final class Builder + { + private Estimate outputRowCount = unknownValue(); + private Estimate outputSizeInBytes = unknownValue(); + + public Builder setOutputRowCount(Estimate outputRowCount) + { + this.outputRowCount = outputRowCount; + return this; + } + + public Builder setOutputSizeInBytes(Estimate outputSizeInBytes) + { + this.outputSizeInBytes = outputSizeInBytes; + return this; + } + + public PlanNodeCost build() + { + return new PlanNodeCost(outputRowCount, outputSizeInBytes); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/event/query/QueryMonitor.java b/presto-main/src/main/java/com/facebook/presto/event/query/QueryMonitor.java index e5d25c384bf44..bf66e49ac3f2b 100644 --- a/presto-main/src/main/java/com/facebook/presto/event/query/QueryMonitor.java +++ b/presto-main/src/main/java/com/facebook/presto/event/query/QueryMonitor.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.eventlistener.SplitCompletedEvent; import com.facebook.presto.spi.eventlistener.SplitFailureInfo; import com.facebook.presto.spi.eventlistener.SplitStatistics; +import com.facebook.presto.spi.eventlistener.StageCpuDistribution; import com.facebook.presto.transaction.TransactionId; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -50,6 +51,8 @@ import io.airlift.json.JsonCodec; import io.airlift.log.Logger; import io.airlift.node.NodeInfo; +import io.airlift.stats.Distribution; +import io.airlift.stats.Distribution.DistributionSnapshot; import org.joda.time.DateTime; import javax.annotation.Nullable; @@ -111,6 +114,7 @@ public void queryCreatedEvent(QueryInfo queryInfo) queryInfo.getSession().getSource(), queryInfo.getSession().getCatalog(), queryInfo.getSession().getSchema(), + queryInfo.getResourceGroupName(), mergeSessionAndCatalogProperties(queryInfo), serverAddress, serverVersion, @@ -172,6 +176,11 @@ public void queryCompletedEvent(QueryInfo queryInfo) tableFinishInfo.map(TableFinishInfo::isJsonLengthLimitExceeded))); } + ImmutableList.Builder operatorSummaries = ImmutableList.builder(); + for (OperatorStats summary : queryInfo.getQueryStats().getOperatorSummaries()) { + operatorSummaries.add(objectMapper.writeValueAsString(summary)); + } + eventListenerManager.queryCompleted( new QueryCompletedEvent( new QueryMetadata( @@ -190,9 +199,11 @@ public void queryCompletedEvent(QueryInfo queryInfo) queryStats.getPeakMemoryReservation().toBytes(), queryStats.getRawInputDataSize().toBytes(), queryStats.getRawInputPositions(), + queryStats.getCumulativeMemory(), queryStats.getCompletedDrivers(), queryInfo.isCompleteInfo(), - objectMapper.writeValueAsString(queryInfo.getQueryStats().getOperatorSummaries())), + getCpuDistributions(queryInfo), + operatorSummaries.build()), new QueryContext( queryInfo.getSession().getUser(), queryInfo.getSession().getPrincipal(), @@ -202,6 +213,7 @@ public void queryCompletedEvent(QueryInfo queryInfo) queryInfo.getSession().getSource(), queryInfo.getSession().getCatalog(), queryInfo.getSession().getSchema(), + queryInfo.getResourceGroupName(), mergeSessionAndCatalogProperties(queryInfo), serverAddress, serverVersion, @@ -365,4 +377,49 @@ private void splitCompletedEvent(TaskId taskId, DriverStats driverStats, @Nullab log.error(e, "Error processing split completion event for task %s", taskId); } } + + private static List getCpuDistributions(QueryInfo queryInfo) + { + if (!queryInfo.getOutputStage().isPresent()) { + return ImmutableList.of(); + } + + ImmutableList.Builder builder = ImmutableList.builder(); + populateDistribution(queryInfo.getOutputStage().get(), builder); + + return builder.build(); + } + + private static void populateDistribution(StageInfo stageInfo, ImmutableList.Builder distributions) + { + distributions.add(computeCpuDistribution(stageInfo)); + for (StageInfo subStage : stageInfo.getSubStages()) { + populateDistribution(subStage, distributions); + } + } + + private static StageCpuDistribution computeCpuDistribution(StageInfo stageInfo) + { + Distribution cpuDistribution = new Distribution(); + + for (TaskInfo taskInfo : stageInfo.getTasks()) { + cpuDistribution.add(taskInfo.getStats().getTotalCpuTime().toMillis()); + } + + DistributionSnapshot snapshot = cpuDistribution.snapshot(); + + return new StageCpuDistribution( + stageInfo.getStageId().getId(), + stageInfo.getTasks().size(), + snapshot.getP25(), + snapshot.getP50(), + snapshot.getP75(), + snapshot.getP90(), + snapshot.getP95(), + snapshot.getP99(), + snapshot.getMin(), + snapshot.getMax(), + (long) snapshot.getTotal(), + snapshot.getTotal() / snapshot.getCount()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java index e03f96e8fb26d..733e54ad14846 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CreateTableTask.java @@ -31,11 +31,12 @@ import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.TableElement; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; -import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -45,6 +46,7 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; @@ -84,7 +86,7 @@ public ListenableFuture execute(CreateTable statement, TransactionManager tra return immediateFuture(null); } - List columns = new ArrayList<>(); + LinkedHashMap columns = new LinkedHashMap<>(); Map inheritedProperties = ImmutableMap.of(); boolean includingProperties = false; for (TableElement element : statement.getElements()) { @@ -94,7 +96,10 @@ public ListenableFuture execute(CreateTable statement, TransactionManager tra if ((type == null) || type.equals(UNKNOWN)) { throw new SemanticException(TYPE_MISMATCH, column, "Unknown type for column '%s' ", column.getName()); } - columns.add(new ColumnMetadata(column.getName(), type, column.getComment().orElse(null), false)); + if (columns.containsKey(column.getName().toLowerCase())) { + throw new SemanticException(DUPLICATE_COLUMN_NAME, column, "Column name '%s' specified more than once", column.getName()); + } + columns.put(column.getName().toLowerCase(), new ColumnMetadata(column.getName(), type, column.getComment().orElse(null), false)); } else if (element instanceof LikeClause) { LikeClause likeClause = (LikeClause) element; @@ -121,7 +126,12 @@ else if (element instanceof LikeClause) { likeTableMetadata.getColumns().stream() .filter(column -> !column.isHidden()) - .forEach(columns::add); + .forEach(column -> { + if (columns.containsKey(column.getName().toLowerCase())) { + throw new SemanticException(DUPLICATE_COLUMN_NAME, element, "Column name '%s' specified more than once", column.getName()); + } + columns.put(column.getName().toLowerCase(), column); + }); } else { throw new PrestoException(GENERIC_INTERNAL_ERROR, "Invalid TableElement: " + element.getClass().getName()); @@ -143,7 +153,7 @@ else if (element instanceof LikeClause) { Map finalProperties = combineProperties(statement.getProperties().keySet(), properties, inheritedProperties); - ConnectorTableMetadata tableMetadata = new ConnectorTableMetadata(tableName.asSchemaTableName(), columns, finalProperties, statement.getComment()); + ConnectorTableMetadata tableMetadata = new ConnectorTableMetadata(tableName.asSchemaTableName(), ImmutableList.copyOf(columns.values()), finalProperties, statement.getComment()); metadata.createTable(session, tableName.getCatalogName(), tableMetadata); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java index a7513d88dd4f6..02288e3b90c6c 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java @@ -21,6 +21,7 @@ import com.facebook.presto.security.AccessControl; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.transaction.TransactionManager; @@ -200,6 +201,12 @@ public QueryInfo getQueryInfo() return stateMachine.updateQueryInfo(Optional.empty()); } + @Override + public Plan getQueryPlan() + { + throw new UnsupportedOperationException(); + } + @Override public QueryState getState() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java index 6f7734cc995dd..20ee5867e80d0 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/FailedQueryExecution.java @@ -19,6 +19,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.transaction.TransactionManager; import io.airlift.units.Duration; @@ -66,6 +67,12 @@ public QueryState getState() return queryInfo.getState(); } + @Override + public Plan getQueryPlan() + { + throw new UnsupportedOperationException(); + } + @Override public VersionedMemoryPoolId getMemoryPool() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/GrantTask.java b/presto-main/src/main/java/com/facebook/presto/execution/GrantTask.java index 9e86608f83566..e6360af38e06a 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/GrantTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/GrantTask.java @@ -68,7 +68,7 @@ public ListenableFuture execute(Grant statement, TransactionManager transacti // verify current identity has permissions to grant permissions for (Privilege privilege : privileges) { - accessControl.checkCanGrantTablePrivilege(session.getRequiredTransactionId(), session.getIdentity(), privilege, tableName); + accessControl.checkCanGrantTablePrivilege(session.getRequiredTransactionId(), session.getIdentity(), privilege, tableName, statement.getGrantee(), statement.isWithGrantOption()); } metadata.grantTablePrivileges(session, tableName, privileges, statement.getGrantee(), statement.isWithGrantOption()); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java index 92de2e918a3cb..098ed4dbecc7f 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/QueryExecution.java @@ -18,6 +18,7 @@ import com.facebook.presto.memory.VersionedMemoryPoolId; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Statement; import io.airlift.units.Duration; @@ -37,6 +38,8 @@ public interface QueryExecution void setResourceGroup(ResourceGroupId resourceGroupId); + Plan getQueryPlan(); + Duration waitForStateChange(QueryState currentState, Duration maxWait) throws InterruptedException; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/QueryManager.java b/presto-main/src/main/java/com/facebook/presto/execution/QueryManager.java index ba609a64c27cf..66916a1a25a50 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/QueryManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/QueryManager.java @@ -16,6 +16,7 @@ import com.facebook.presto.server.SessionSupplier; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.sql.planner.Plan; import io.airlift.units.Duration; import java.util.List; @@ -32,6 +33,8 @@ Duration waitForStateChange(QueryId queryId, QueryState currentState, Duration m Optional getQueryResourceGroup(QueryId queryId); + Plan getQueryPlan(QueryId queryId); + Optional getQueryState(QueryId queryId); void recordHeartbeat(QueryId queryId); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/QueryStateMachine.java b/presto-main/src/main/java/com/facebook/presto/execution/QueryStateMachine.java index 04ba6201dffe5..d03aed90a6809 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/QueryStateMachine.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/QueryStateMachine.java @@ -304,6 +304,7 @@ public QueryInfo getQueryInfo(Optional rootStage) int totalDrivers = 0; int queuedDrivers = 0; int runningDrivers = 0; + int blockedDrivers = 0; int completedDrivers = 0; long cumulativeMemory = 0; @@ -338,6 +339,7 @@ public QueryInfo getQueryInfo(Optional rootStage) totalDrivers += stageStats.getTotalDrivers(); queuedDrivers += stageStats.getQueuedDrivers(); runningDrivers += stageStats.getRunningDrivers(); + blockedDrivers += stageStats.getBlockedDrivers(); completedDrivers += stageStats.getCompletedDrivers(); cumulativeMemory += stageStats.getCumulativeMemory(); @@ -393,6 +395,7 @@ public QueryInfo getQueryInfo(Optional rootStage) totalDrivers, queuedDrivers, runningDrivers, + blockedDrivers, completedDrivers, cumulativeMemory, @@ -610,13 +613,7 @@ public void onFailure(Throwable throwable) private boolean transitionToFinished() { - try { - metadata.cleanupQuery(session); - } - catch (Throwable t) { - log.error("Error during cleanupQuery: %s", t); - } - + cleanupQueryQuietly(); recordDoneStats(); return queryState.setIf(FINISHED, currentState -> !currentState.isDone()); @@ -624,20 +621,13 @@ private boolean transitionToFinished() public boolean transitionToFailed(Throwable throwable) { - try { - metadata.cleanupQuery(session); - } - catch (Throwable t) { - log.error("Error during cleanupQuery: %s", t); - } - - requireNonNull(throwable, "throwable is null"); - + cleanupQueryQuietly(); recordDoneStats(); // NOTE: The failure cause must be set before triggering the state change, so // listeners can observe the exception. This is safe because the failure cause // can only be observed if the transition to FAILED is successful. + requireNonNull(throwable, "throwable is null"); failureCause.compareAndSet(null, toFailure(throwable)); boolean failed = queryState.setIf(FAILED, currentState -> !currentState.isDone()); @@ -654,6 +644,7 @@ public boolean transitionToFailed(Throwable throwable) public boolean transitionToCanceled() { + cleanupQueryQuietly(); recordDoneStats(); // NOTE: The failure cause must be set before triggering the state change, so @@ -669,6 +660,16 @@ public boolean transitionToCanceled() return canceled; } + private void cleanupQueryQuietly() + { + try { + metadata.cleanupQuery(session); + } + catch (Throwable t) { + log.error("Error cleaning up query: %s", t); + } + } + private void recordDoneStats() { Duration durationSinceCreation = nanosSince(createNanos).convertToMostSuccinctTimeUnit(); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/QueryStats.java b/presto-main/src/main/java/com/facebook/presto/execution/QueryStats.java index d162b55622fd6..5188e62ba45d2 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/QueryStats.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/QueryStats.java @@ -56,6 +56,7 @@ public class QueryStats private final int totalDrivers; private final int queuedDrivers; private final int runningDrivers; + private final int blockedDrivers; private final int completedDrivers; private final double cumulativeMemory; @@ -96,6 +97,7 @@ public QueryStats() this.finishingTime = null; this.totalTasks = 0; this.runningTasks = 0; + this.blockedDrivers = 0; this.completedTasks = 0; this.totalDrivers = 0; this.queuedDrivers = 0; @@ -141,6 +143,7 @@ public QueryStats( @JsonProperty("totalDrivers") int totalDrivers, @JsonProperty("queuedDrivers") int queuedDrivers, @JsonProperty("runningDrivers") int runningDrivers, + @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, @JsonProperty("cumulativeMemory") double cumulativeMemory, @@ -191,6 +194,8 @@ public QueryStats( this.queuedDrivers = queuedDrivers; checkArgument(runningDrivers >= 0, "runningDrivers is negative"); this.runningDrivers = runningDrivers; + checkArgument(blockedDrivers >= 0, "blockedDrivers is negative"); + this.blockedDrivers = blockedDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; @@ -325,6 +330,12 @@ public int getRunningDrivers() return runningDrivers; } + @JsonProperty + public int getBlockedDrivers() + { + return blockedDrivers; + } + @JsonProperty public int getCompletedDrivers() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/RevokeTask.java b/presto-main/src/main/java/com/facebook/presto/execution/RevokeTask.java index 889078ee905ba..73c7d619594f9 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/RevokeTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/RevokeTask.java @@ -68,7 +68,7 @@ public ListenableFuture execute(Revoke statement, TransactionManager transact // verify current identity has permissions to revoke permissions for (Privilege privilege : privileges) { - accessControl.checkCanRevokeTablePrivilege(session.getRequiredTransactionId(), session.getIdentity(), privilege, tableName); + accessControl.checkCanRevokeTablePrivilege(session.getRequiredTransactionId(), session.getIdentity(), privilege, tableName, statement.getGrantee(), statement.isGrantOptionFor()); } metadata.revokeTablePrivileges(session, tableName, privileges, statement.getGrantee(), statement.isGrantOptionFor()); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 1f652770995d2..87067b1fdb609 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -18,6 +18,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.execution.scheduler.ExecutionPolicy; import com.facebook.presto.execution.scheduler.NodeScheduler; @@ -104,7 +105,9 @@ public final class SqlQueryExecution private final FailureDetector failureDetector; private final QueryExplainer queryExplainer; + private final CostCalculator costCalculator; private final AtomicReference queryScheduler = new AtomicReference<>(); + private final AtomicReference queryPlan = new AtomicReference<>(); private final NodeTaskMap nodeTaskMap; private final ExecutionPolicy executionPolicy; private final List parameters; @@ -122,6 +125,7 @@ public SqlQueryExecution(QueryId queryId, SplitManager splitManager, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + CostCalculator costCalculator, List planOptimizers, RemoteTaskFactory remoteTaskFactory, LocationFactory locationFactory, @@ -142,6 +146,7 @@ public SqlQueryExecution(QueryId queryId, this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.locationFactory = requireNonNull(locationFactory, "locationFactory is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -303,8 +308,9 @@ private PlanRoot doAnalyzeQuery() // plan query PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser); + LogicalPlanner logicalPlanner = new LogicalPlanner(stateMachine.getSession(), planOptimizers, idAllocator, metadata, sqlParser, costCalculator); Plan plan = logicalPlanner.plan(analysis); + queryPlan.set(plan); // extract inputs List inputs = new InputExtractor(metadata, stateMachine.getSession()).extractInputs(plan.getRoot()); @@ -504,6 +510,11 @@ public void setResourceGroup(ResourceGroupId resourceGroupId) stateMachine.setResourceGroup(resourceGroupId); } + public Plan getQueryPlan() + { + return queryPlan.get(); + } + private QueryInfo buildQueryInfo(SqlQueryScheduler scheduler) { Optional stageInfo = Optional.empty(); @@ -560,6 +571,7 @@ public static class SqlQueryExecutionFactory private final SplitManager splitManager; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final CostCalculator costCalculator; private final List planOptimizers; private final RemoteTaskFactory remoteTaskFactory; private final TransactionManager transactionManager; @@ -580,6 +592,7 @@ public static class SqlQueryExecutionFactory SplitManager splitManager, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + CostCalculator costCalculator, PlanOptimizers planOptimizers, RemoteTaskFactory remoteTaskFactory, TransactionManager transactionManager, @@ -610,7 +623,7 @@ public static class SqlQueryExecutionFactory this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); this.executionPolicies = requireNonNull(executionPolicies, "schedulerPolicies is null"); - + this.costCalculator = requireNonNull(costCalculator, "cost calculator is null"); this.planOptimizers = planOptimizers.get(); } @@ -634,6 +647,7 @@ public SqlQueryExecution createQueryExecution(QueryId queryId, String query, Ses splitManager, nodePartitioningManager, nodeScheduler, + costCalculator, planOptimizers, remoteTaskFactory, locationFactory, diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java index f177420dd076a..34ceb5f892893 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java @@ -30,6 +30,7 @@ import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.Expression; @@ -293,6 +294,19 @@ public Optional getQueryResourceGroup(QueryId queryId) return Optional.empty(); } + @Override + public Plan getQueryPlan(QueryId queryId) + { + requireNonNull(queryId, "queryId is null"); + + QueryExecution query = queries.get(queryId); + if (query == null) { + throw new NoSuchElementException(); + } + + return query.getQueryPlan(); + } + @Override public Optional getQueryState(QueryId queryId) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManagerStats.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManagerStats.java index 4c1aa06dc7391..f74cf3d204275 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManagerStats.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManagerStats.java @@ -24,6 +24,7 @@ import static com.facebook.presto.spi.StandardErrorCode.ABANDONED_QUERY; import static com.facebook.presto.spi.StandardErrorCode.USER_CANCELED; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; public class SqlQueryManagerStats { @@ -37,6 +38,9 @@ public class SqlQueryManagerStats private final CounterStat internalFailures = new CounterStat(); private final CounterStat externalFailures = new CounterStat(); private final CounterStat insufficientResourcesFailures = new CounterStat(); + private final CounterStat consumedInputRows = new CounterStat(); + private final CounterStat consumedInputBytes = new CounterStat(); + private final CounterStat consumedCpuTimeSecs = new CounterStat(); private final TimeStat executionTime = new TimeStat(MILLISECONDS); private final DistributionStat wallInputBytesRate = new DistributionStat(); private final DistributionStat cpuInputByteRate = new DistributionStat(); @@ -58,8 +62,12 @@ public void queryFinished(QueryInfo info) long rawInputBytes = info.getQueryStats().getRawInputDataSize().toBytes(); - long executionWallMillis = info.getQueryStats().getEndTime().getMillis() - info.getQueryStats().getCreateTime().getMillis(); - executionTime.add(executionWallMillis, MILLISECONDS); + consumedCpuTimeSecs.update((long) info.getQueryStats().getTotalCpuTime().getValue(SECONDS)); + consumedInputBytes.update(info.getQueryStats().getRawInputDataSize().toBytes()); + consumedInputRows.update(info.getQueryStats().getRawInputPositions()); + executionTime.add(info.getQueryStats().getExecutionTime()); + + long executionWallMillis = info.getQueryStats().getExecutionTime().toMillis(); if (executionWallMillis > 0) { wallInputBytesRate.add(rawInputBytes * 1000 / executionWallMillis); } @@ -123,6 +131,27 @@ public CounterStat getFailedQueries() return failedQueries; } + @Managed + @Nested + public CounterStat getConsumedInputRows() + { + return consumedInputRows; + } + + @Managed + @Nested + public CounterStat getConsumedInputBytes() + { + return consumedInputBytes; + } + + @Managed + @Nested + public CounterStat getConsumedCpuTimeSecs() + { + return consumedCpuTimeSecs; + } + @Managed @Nested public TimeStat getExecutionTime() diff --git a/presto-main/src/main/java/com/facebook/presto/execution/StageStateMachine.java b/presto-main/src/main/java/com/facebook/presto/execution/StageStateMachine.java index d21e8c60c07a3..8f7ff3bc107d7 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/StageStateMachine.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/StageStateMachine.java @@ -214,6 +214,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier, Su int totalDrivers = 0; int queuedDrivers = 0; int runningDrivers = 0; + int blockedDrivers = 0; int completedDrivers = 0; long cumulativeMemory = 0; @@ -253,6 +254,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier, Su totalDrivers += taskStats.getTotalDrivers(); queuedDrivers += taskStats.getQueuedDrivers(); runningDrivers += taskStats.getRunningDrivers(); + blockedDrivers += taskStats.getBlockedDrivers(); completedDrivers += taskStats.getCompletedDrivers(); cumulativeMemory += taskStats.getCumulativeMemory(); @@ -298,6 +300,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier, Su totalDrivers, queuedDrivers, runningDrivers, + blockedDrivers, completedDrivers, cumulativeMemory, diff --git a/presto-main/src/main/java/com/facebook/presto/execution/StageStats.java b/presto-main/src/main/java/com/facebook/presto/execution/StageStats.java index 63c79510703c4..c164469a492fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/StageStats.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/StageStats.java @@ -49,6 +49,7 @@ public class StageStats private final int totalDrivers; private final int queuedDrivers; private final int runningDrivers; + private final int blockedDrivers; private final int completedDrivers; private final double cumulativeMemory; @@ -86,6 +87,7 @@ public StageStats() this.totalDrivers = 0; this.queuedDrivers = 0; this.runningDrivers = 0; + this.blockedDrivers = 0; this.completedDrivers = 0; this.cumulativeMemory = 0.0; this.totalMemoryReservation = null; @@ -121,6 +123,7 @@ public StageStats( @JsonProperty("totalDrivers") int totalDrivers, @JsonProperty("queuedDrivers") int queuedDrivers, @JsonProperty("runningDrivers") int runningDrivers, + @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, @JsonProperty("cumulativeMemory") double cumulativeMemory, @@ -163,6 +166,8 @@ public StageStats( this.queuedDrivers = queuedDrivers; checkArgument(runningDrivers >= 0, "runningDrivers is negative"); this.runningDrivers = runningDrivers; + checkArgument(blockedDrivers >= 0, "blockedDrivers is negative"); + this.blockedDrivers = blockedDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; @@ -252,6 +257,12 @@ public int getRunningDrivers() return runningDrivers; } + @JsonProperty + public int getBlockedDrivers() + { + return blockedDrivers; + } + @JsonProperty public int getCompletedDrivers() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java b/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java index 2b3b7010d4fae..6652791540a8b 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java @@ -227,6 +227,9 @@ public synchronized TaskHandle addTask(TaskId taskId, DoubleSupplier utilization { requireNonNull(taskId, "taskId is null"); requireNonNull(utilizationSupplier, "utilizationSupplier is null"); + + log.debug("Task scheduled " + taskId); + TaskHandle taskHandle = new TaskHandle(taskId, utilizationSupplier, initialSplitConcurrency, splitConcurrencyAdjustFrequency); tasks.add(taskHandle); return taskHandle; @@ -256,6 +259,8 @@ public void removeTask(TaskHandle taskHandle) int priorityLevel = calculatePriorityLevel(threadUsageNanos); completedTasksPerLevel.incrementAndGet(priorityLevel); + log.debug("Task finished or failed " + taskHandle.getTaskId()); + // replace blocked splits that were terminated addNewEntrants(); } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/FifoQueue.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/FifoQueue.java index 03d944c85b7a0..ba79aa705d29d 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/FifoQueue.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/FifoQueue.java @@ -74,4 +74,10 @@ public boolean isEmpty() { return delegate.isEmpty(); } + + @Override + public Iterator iterator() + { + return delegate.iterator(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/IndexedPriorityQueue.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/IndexedPriorityQueue.java index 11158dee02081..6b288344f5397 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/IndexedPriorityQueue.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/IndexedPriorityQueue.java @@ -20,6 +20,7 @@ import java.util.TreeSet; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterators.transform; import static java.util.Objects.requireNonNull; /** @@ -111,6 +112,12 @@ public boolean isEmpty() return queue.isEmpty(); } + @Override + public Iterator iterator() + { + return transform(queue.iterator(), Entry::getValue); + } + private static final class Entry { private final E value; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java index 3b7f6f5c84af7..f7f4b396f4cc0 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java @@ -15,6 +15,9 @@ import com.facebook.presto.execution.QueryExecution; import com.facebook.presto.execution.QueryState; +import com.facebook.presto.server.QueryStateInfo; +import com.facebook.presto.server.ResourceGroupStateInfo; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.resourceGroups.ResourceGroup; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; @@ -39,7 +42,9 @@ import java.util.function.BiConsumer; import static com.facebook.presto.SystemSessionProperties.getQueryPriority; +import static com.facebook.presto.server.QueryStateInfo.createQueryStateInfo; import static com.facebook.presto.spi.ErrorType.USER_ERROR; +import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_TIME_LIMIT; import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_QUEUE; import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.FULL; @@ -51,6 +56,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.units.DataSize.Unit.BYTE; +import static java.lang.Math.min; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -112,6 +118,10 @@ public class InternalResourceGroup private SchedulingPolicy schedulingPolicy = FAIR; @GuardedBy("root") private boolean jmxExport; + @GuardedBy("root") + private Duration queuedTimeLimit = new Duration(Long.MAX_VALUE, MILLISECONDS); + @GuardedBy("root") + private Duration runningTimeLimit = new Duration(Long.MAX_VALUE, MILLISECONDS); protected InternalResourceGroup(Optional parent, String name, BiConsumer jmxExportListener, Executor executor) { @@ -138,23 +148,14 @@ public ResourceGroupInfo getInfo() .map(InternalResourceGroup::getInfo) .collect(toImmutableList()); - ResourceGroupState resourceGroupState; - if (canRunMore()) { - resourceGroupState = CAN_RUN; - } - else if (canQueueMore()) { - resourceGroupState = CAN_QUEUE; - } - else { - resourceGroupState = FULL; - } - return new ResourceGroupInfo( id, new DataSize(softMemoryLimitBytes, BYTE), maxRunningQueries, + runningTimeLimit, maxQueuedQueries, - resourceGroupState, + queuedTimeLimit, + getState(), eligibleSubGroups.size(), new DataSize(cachedMemoryUsageBytes, BYTE), runningQueries.size() + descendantRunningQueries, @@ -163,6 +164,54 @@ else if (canQueueMore()) { } } + public ResourceGroupStateInfo getStateInfo() + { + synchronized (root) { + return new ResourceGroupStateInfo( + id, + getState(), + DataSize.succinctBytes(softMemoryLimitBytes), + DataSize.succinctBytes(cachedMemoryUsageBytes), + maxRunningQueries, + maxQueuedQueries, + runningTimeLimit, + queuedTimeLimit, + getAggregatedRunningQueriesInfo(), + queuedQueries.size() + descendantQueuedQueries); + } + } + + private ResourceGroupState getState() + { + synchronized (root) { + if (canRunMore()) { + return CAN_RUN; + } + else if (canQueueMore()) { + return CAN_QUEUE; + } + else { + return FULL; + } + } + } + + private List getAggregatedRunningQueriesInfo() + { + synchronized (root) { + if (subGroups.isEmpty()) { + return runningQueries.stream() + .map(QueryExecution::getQueryInfo) + .map(queryInfo -> createQueryStateInfo(queryInfo, Optional.of(id), Optional.empty())) + .collect(toImmutableList()); + } + return subGroups.values().stream() + .map(InternalResourceGroup::getAggregatedRunningQueriesInfo) + .flatMap(List::stream) + .collect(toImmutableList()); + } + } + @Override public ResourceGroupId getId() { @@ -185,6 +234,27 @@ public int getQueuedQueries() } } + @Managed + public int getWaitingQueuedQueries() + { + synchronized (root) { + // For leaf group, when no queries can run, all queued queries are waiting for resources on this resource group. + if (subGroups.isEmpty()) { + return queuedQueries.size(); + } + + // For internal groups, when no queries can run, only queries that could run on its subgroups are waiting for resources on this group. + int waitingQueuedQueries = 0; + for (InternalResourceGroup subGroup : subGroups.values()) { + if (subGroup.canRunMore()) { + waitingQueuedQueries += min(subGroup.getQueuedQueries(), subGroup.getMaxRunningQueries() - subGroup.getRunningQueries()); + } + } + + return waitingQueuedQueries; + } + } + @Override public DataSize getSoftMemoryLimit() { @@ -404,6 +474,38 @@ public void setJmxExport(boolean export) jmxExportListener.accept(this, export); } + @Override + public Duration getQueuedTimeLimit() + { + synchronized (root) { + return queuedTimeLimit; + } + } + + @Override + public void setQueuedTimeLimit(Duration queuedTimeLimit) + { + synchronized (root) { + this.queuedTimeLimit = queuedTimeLimit; + } + } + + @Override + public Duration getRunningTimeLimit() + { + synchronized (root) { + return runningTimeLimit; + } + } + + @Override + public void setRunningTimeLimit(Duration runningTimeLimit) + { + synchronized (root) { + this.runningTimeLimit = runningTimeLimit; + } + } + public InternalResourceGroup getOrCreateSubGroup(String name) { requireNonNull(name, "name is null"); @@ -630,6 +732,28 @@ protected boolean internalStartNext() } } + protected void enforceTimeLimits() + { + checkState(Thread.holdsLock(root), "Must hold lock to enforce time limits"); + synchronized (root) { + for (InternalResourceGroup group : subGroups.values()) { + group.enforceTimeLimits(); + } + for (QueryExecution query : runningQueries) { + Duration runningTime = query.getQueryInfo().getQueryStats().getExecutionTime(); + if (runningQueries.contains(query) && runningTime != null && runningTime.compareTo(runningTimeLimit) > 0) { + query.fail(new PrestoException(EXCEEDED_TIME_LIMIT, "query exceeded resource group runtime limit")); + } + } + for (QueryExecution query : queuedQueries) { + Duration elapsedTime = query.getQueryInfo().getQueryStats().getElapsedTime(); + if (queuedQueries.contains(query) && elapsedTime != null && elapsedTime.compareTo(queuedTimeLimit) > 0) { + query.fail(new PrestoException(EXCEEDED_TIME_LIMIT, "query exceeded resource group queued time limit")); + } + } + } + } + private static int getSubGroupSchedulingPriority(SchedulingPolicy policy, InternalResourceGroup group) { if (policy == QUERY_PRIORITY) { @@ -693,7 +817,7 @@ private boolean canRunMore() double penalty = (cpuUsageMillis - softCpuLimitMillis) / (double) (hardCpuLimitMillis - softCpuLimitMillis); maxRunning = (int) Math.floor(maxRunning * (1 - penalty)); // Always penalize by at least one - maxRunning = Math.min(maxRunningQueries - 1, maxRunning); + maxRunning = min(maxRunningQueries - 1, maxRunning); // Always allow at least one running query maxRunning = Math.max(1, maxRunning); } @@ -741,6 +865,7 @@ public RootInternalResourceGroup(String name, BiConsumer rootGroups = new CopyOnWriteArrayList<>(); private final ConcurrentMap groups = new ConcurrentHashMap<>(); private final AtomicReference configurationManager = new AtomicReference<>(); - private final ClusterMemoryPoolManager memoryPoolManager; + private final ResourceGroupConfigurationManagerContext configurationManagerContext; private final MBeanExporter exporter; private final AtomicBoolean started = new AtomicBoolean(); private final AtomicLong lastCpuQuotaGenerationNanos = new AtomicLong(System.nanoTime()); private final Map configurationManagerFactories = new ConcurrentHashMap<>(); @Inject - public InternalResourceGroupManager(LegacyResourceGroupConfigurationManagerFactory builtinFactory, ClusterMemoryPoolManager memoryPoolManager, MBeanExporter exporter) + public InternalResourceGroupManager(LegacyResourceGroupConfigurationManagerFactory builtinFactory, ClusterMemoryPoolManager memoryPoolManager, NodeInfo nodeInfo, MBeanExporter exporter) { this.exporter = requireNonNull(exporter, "exporter is null"); - this.memoryPoolManager = requireNonNull(memoryPoolManager, "memoryPoolManager is null"); + this.configurationManagerContext = new ResourceGroupConfigurationManagerContextInstance(memoryPoolManager, nodeInfo.getEnvironment()); requireNonNull(builtinFactory, "builtinFactory is null"); addConfigurationManagerFactory(builtinFactory); } @@ -101,6 +105,15 @@ public ResourceGroupInfo getResourceGroupInfo(ResourceGroupId id) return groups.get(id).getInfo(); } + @Override + public ResourceGroupStateInfo getResourceGroupStateInfo(ResourceGroupId id) + { + if (!groups.containsKey(id)) { + throw new NoSuchElementException(); + } + return groups.get(id).getStateInfo(); + } + @Override public void submit(Statement statement, QueryExecution queryExecution, Executor executor) { @@ -154,7 +167,7 @@ public void setConfigurationManager(String name, Map properties) ResourceGroupConfigurationManagerFactory configurationManagerFactory = configurationManagerFactories.get(name); checkState(configurationManagerFactory != null, "Resource group configuration manager %s is not registered", name); - ResourceGroupConfigurationManager configurationManager = configurationManagerFactory.create(ImmutableMap.copyOf(properties), () -> memoryPoolManager); + ResourceGroupConfigurationManager configurationManager = configurationManagerFactory.create(ImmutableMap.copyOf(properties), configurationManagerContext); checkState(this.configurationManager.compareAndSet(null, configurationManager), "configurationManager already set"); log.info("-- Loaded resource group configuration manager %s --", name); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java index b51b129e8bb0c..9fadfda21f9e1 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/NoOpResourceGroupManager.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.resourceGroups; import com.facebook.presto.execution.QueryExecution; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManagerFactory; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; @@ -39,6 +40,12 @@ public ResourceGroupInfo getResourceGroupInfo(ResourceGroupId id) throw new UnsupportedOperationException(); } + @Override + public ResourceGroupStateInfo getResourceGroupStateInfo(ResourceGroupId id) + { + throw new UnsupportedOperationException(); + } + @Override public void addConfigurationManagerFactory(ResourceGroupConfigurationManagerFactory factory) { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupConfigurationManagerContextInstance.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupConfigurationManagerContextInstance.java new file mode 100644 index 0000000000000..04530eca7ce4d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupConfigurationManagerContextInstance.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.resourceGroups; + +import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; +import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManagerContext; + +import static java.util.Objects.requireNonNull; + +public class ResourceGroupConfigurationManagerContextInstance + implements ResourceGroupConfigurationManagerContext +{ + private final ClusterMemoryPoolManager memoryPoolManager; + private final String environment; + + public ResourceGroupConfigurationManagerContextInstance(ClusterMemoryPoolManager memoryPoolManager, String environment) + { + this.memoryPoolManager = requireNonNull(memoryPoolManager, "memoryPoolManager is null"); + this.environment = requireNonNull(environment, "environment is null"); + } + + @Override + public ClusterMemoryPoolManager getMemoryPoolManager() + { + return memoryPoolManager; + } + + @Override + public String getEnvironment() + { + return environment; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java index 8749f7ab876b1..b0e17c36d5115 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.resourceGroups; import com.facebook.presto.execution.QueryQueueManager; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupConfigurationManagerFactory; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; @@ -23,6 +24,8 @@ public interface ResourceGroupManager { ResourceGroupInfo getResourceGroupInfo(ResourceGroupId id); + ResourceGroupStateInfo getResourceGroupStateInfo(ResourceGroupId id); + void addConfigurationManagerFactory(ResourceGroupConfigurationManagerFactory factory); void loadConfigurationManager() throws Exception; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/StochasticPriorityQueue.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/StochasticPriorityQueue.java index 4619e877b6978..bc75e6b79d17e 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/StochasticPriorityQueue.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/StochasticPriorityQueue.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.resourceGroups; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; @@ -138,6 +139,13 @@ public boolean isEmpty() return index.isEmpty(); } + @Override + public Iterator iterator() + { + // Since poll() is not deterministic ordering is not required + return index.keySet().iterator(); + } + private static final class Node { private Optional> parent; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/UpdateablePriorityQueue.java b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/UpdateablePriorityQueue.java index 3181ef10a7c54..d98933c704599 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/UpdateablePriorityQueue.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/resourceGroups/UpdateablePriorityQueue.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.resourceGroups; interface UpdateablePriorityQueue + extends Iterable { boolean addOrUpdate(E element, int priority); diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java index 6a514bb9e8ad2..3750de9b05810 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java @@ -169,7 +169,7 @@ static List> extractPhases(Collection fragment } private static class Visitor - extends PlanVisitor> + extends PlanVisitor, PlanFragmentId> { private final Map fragments; private final DirectedGraph graph; @@ -184,7 +184,13 @@ public Visitor(Collection fragments, DirectedGraph processFragment(PlanFragmentId planFragmentId) { - return fragmentSources.computeIfAbsent(planFragmentId, fragmentId -> processFragment(fragments.get(fragmentId))); + if (fragmentSources.containsKey(planFragmentId)) { + return fragmentSources.get(planFragmentId); + } + + Set fragment = processFragment(fragments.get(planFragmentId)); + fragmentSources.put(planFragmentId, fragment); + return fragment; } private Set processFragment(PlanFragment fragment) diff --git a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java index ddf3e0f20f67e..282658bf507d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java @@ -22,7 +22,6 @@ import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import com.google.common.util.concurrent.FutureCallback; @@ -36,11 +35,12 @@ import static com.facebook.presto.execution.scheduler.ScheduleResult.BlockedReason.SPLIT_QUEUES_FULL; import static com.facebook.presto.execution.scheduler.ScheduleResult.BlockedReason.WAITING_FOR_SOURCE; +import static com.facebook.presto.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; import static io.airlift.concurrent.MoreFutures.getFutureValue; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class SourcePartitionedScheduler @@ -86,12 +86,10 @@ public synchronized ScheduleResult schedule() // try to get the next batch if necessary if (pendingSplits.isEmpty()) { if (batchFuture == null) { - if (!splitSource.isFinished()) { - batchFuture = splitSource.getNextBatch(splitBatchSize); - } - else { + if (splitSource.isFinished()) { return handleNoMoreSplits(); } + batchFuture = splitSource.getNextBatch(splitBatchSize); long start = System.nanoTime(); Futures.addCallback(batchFuture, new FutureCallback>() @@ -158,11 +156,8 @@ private ScheduleResult handleNoMoreSplits() state = State.FINISHED; splitSource.close(); return new ScheduleResult(true, ImmutableSet.of(), 0); - default: - throw new IllegalStateException( - format("SourcePartitionedScheduler expected to be in INITIALIZED or SPLITS_SCHEDULED state" + - " but is in [%s]", state)); } + throw new IllegalStateException("SourcePartitionedScheduler expected to be in INITIALIZED or SPLITS_SCHEDULED state but is in " + state); } @Override @@ -175,11 +170,14 @@ private ScheduleResult scheduleEmptySplit() { state = State.SPLITS_SCHEDULED; + List nodes = splitPlacementPolicy.allNodes(); + checkCondition(!nodes.isEmpty(), NO_NODES_AVAILABLE, "No nodes available to run query"); + Node node = nodes.iterator().next(); + Split emptySplit = new Split( splitSource.getConnectorId(), splitSource.getTransactionHandle(), new EmptySplit(splitSource.getConnectorId())); - Node node = Iterables.getLast(splitPlacementPolicy.allNodes()); Set emptyTask = assignSplits(ImmutableMultimap.of(node, emptySplit)); return new ScheduleResult(false, emptyTask, 1); } diff --git a/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java b/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java index 332a0f2a917fa..17f38ab539fdb 100644 --- a/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java +++ b/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.failureDetector; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.HostAddress; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; @@ -42,7 +43,6 @@ import java.net.SocketTimeoutException; import java.net.URI; import java.net.URISyntaxException; -import java.net.UnknownHostException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -88,6 +88,7 @@ public class HeartbeatFailureDetector private final boolean isEnabled; private final Duration warmupInterval; private final Duration gcGraceInterval; + private final boolean httpsRequired; private final AtomicBoolean started = new AtomicBoolean(); @@ -95,25 +96,28 @@ public class HeartbeatFailureDetector public HeartbeatFailureDetector( @ServiceType("presto") ServiceSelector selector, @ForFailureDetector HttpClient httpClient, - FailureDetectorConfig config, - NodeInfo nodeInfo) + FailureDetectorConfig failureDetectorConfig, + NodeInfo nodeInfo, + InternalCommunicationConfig internalCommunicationConfig) { requireNonNull(selector, "selector is null"); requireNonNull(httpClient, "httpClient is null"); requireNonNull(nodeInfo, "nodeInfo is null"); - requireNonNull(config, "config is null"); - checkArgument(config.getHeartbeatInterval().toMillis() >= 1, "heartbeat interval must be >= 1ms"); + requireNonNull(failureDetectorConfig, "config is null"); + checkArgument(failureDetectorConfig.getHeartbeatInterval().toMillis() >= 1, "heartbeat interval must be >= 1ms"); this.selector = selector; this.httpClient = httpClient; this.nodeInfo = nodeInfo; - this.failureRatioThreshold = config.getFailureRatioThreshold(); - this.heartbeat = config.getHeartbeatInterval(); - this.warmupInterval = config.getWarmupInterval(); - this.gcGraceInterval = config.getExpirationGraceInterval(); + this.failureRatioThreshold = failureDetectorConfig.getFailureRatioThreshold(); + this.heartbeat = failureDetectorConfig.getHeartbeatInterval(); + this.warmupInterval = failureDetectorConfig.getWarmupInterval(); + this.gcGraceInterval = failureDetectorConfig.getExpirationGraceInterval(); - this.isEnabled = config.isEnabled(); + this.isEnabled = failureDetectorConfig.isEnabled(); + + this.httpsRequired = internalCommunicationConfig.isHttpsRequired(); } @PostConstruct @@ -162,11 +166,12 @@ public State getState(HostAddress hostAddress) } Exception lastFailureException = task.getStats().getLastFailureException(); - if (lastFailureException instanceof SocketTimeoutException || lastFailureException instanceof UnknownHostException) { + if (lastFailureException instanceof ConnectException) { return GONE; } - if (lastFailureException instanceof ConnectException) { + if (lastFailureException instanceof SocketTimeoutException) { + // TODO: distinguish between process unresponsiveness (e.g GC pause) and host reboot return UNRESPONSIVE; } @@ -251,18 +256,16 @@ void updateMonitoredServices() } } - private static URI getHttpUri(ServiceDescriptor service) + private URI getHttpUri(ServiceDescriptor descriptor) { - try { - String uri = service.getProperties().get("http"); - if (uri != null) { - return new URI(uri); + String url = descriptor.getProperties().get(httpsRequired ? "https" : "http"); + if (url != null) { + try { + return new URI(url); + } + catch (URISyntaxException ignored) { } } - catch (URISyntaxException e) { - // ignore, not a valid http uri - } - return null; } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java index fe93133fa264d..ca55deb9f1d64 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java @@ -17,6 +17,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.connector.system.GlobalSystemConnector; import com.facebook.presto.failureDetector.FailureDetector; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.NodeState; import com.google.common.base.Splitter; @@ -56,7 +57,6 @@ import static com.google.common.collect.Sets.difference; import static io.airlift.concurrent.Threads.threadsNamed; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static java.util.Arrays.asList; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; @@ -76,6 +76,7 @@ public final class DiscoveryNodeManager private final ConcurrentHashMap nodeStates = new ConcurrentHashMap<>(); private final HttpClient httpClient; private final ScheduledExecutorService nodeStateUpdateExecutor; + private final boolean httpsRequired; @GuardedBy("this") private SetMultimap activeNodesByConnectorId; @@ -97,7 +98,8 @@ public DiscoveryNodeManager( NodeInfo nodeInfo, FailureDetector failureDetector, NodeVersion expectedNodeVersion, - @ForNodeManager HttpClient httpClient) + @ForNodeManager HttpClient httpClient, + InternalCommunicationConfig internalCommunicationConfig) { this.serviceSelector = requireNonNull(serviceSelector, "serviceSelector is null"); this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null"); @@ -105,6 +107,7 @@ public DiscoveryNodeManager( this.expectedNodeVersion = requireNonNull(expectedNodeVersion, "expectedNodeVersion is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.nodeStateUpdateExecutor = newSingleThreadScheduledExecutor(threadsNamed("node-state-poller-%s")); + this.httpsRequired = internalCommunicationConfig.isHttpsRequired(); this.currentNode = refreshNodesInternal(); } @@ -323,16 +326,14 @@ public synchronized Set getCoordinators() return coordinators; } - private static URI getHttpUri(ServiceDescriptor descriptor) + private URI getHttpUri(ServiceDescriptor descriptor) { - for (String type : asList("http", "https")) { - String url = descriptor.getProperties().get(type); - if (url != null) { - try { - return new URI(url); - } - catch (URISyntaxException ignored) { - } + String url = descriptor.getProperties().get(httpsRequired ? "https" : "http"); + if (url != null) { + try { + return new URI(url); + } + catch (URISyntaxException ignored) { } } return null; diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index ad08fdd9fc6f9..5689c20efccf1 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -85,11 +85,13 @@ import com.facebook.presto.operator.scalar.DateTimeFunctions; import com.facebook.presto.operator.scalar.EmptyMapConstructor; import com.facebook.presto.operator.scalar.FailureFunction; +import com.facebook.presto.operator.scalar.GroupingOperationFunction; import com.facebook.presto.operator.scalar.HyperLogLogFunctions; import com.facebook.presto.operator.scalar.JoniRegexpCasts; import com.facebook.presto.operator.scalar.JoniRegexpFunctions; import com.facebook.presto.operator.scalar.JsonFunctions; import com.facebook.presto.operator.scalar.JsonOperators; +import com.facebook.presto.operator.scalar.ListLiteralCast; import com.facebook.presto.operator.scalar.MapCardinalityFunction; import com.facebook.presto.operator.scalar.MapDistinctFromOperator; import com.facebook.presto.operator.scalar.MapEqualOperator; @@ -100,6 +102,7 @@ import com.facebook.presto.operator.scalar.MapValues; import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.operator.scalar.Re2JRegexpFunctions; +import com.facebook.presto.operator.scalar.RepeatFunction; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.SequenceFunction; import com.facebook.presto.operator.scalar.StringFunctions; @@ -150,6 +153,7 @@ import com.facebook.presto.type.TimestampOperators; import com.facebook.presto.type.TimestampWithTimeZoneOperators; import com.facebook.presto.type.TinyintOperators; +import com.facebook.presto.type.TypeRegistry; import com.facebook.presto.type.UnknownOperators; import com.facebook.presto.type.VarbinaryOperators; import com.facebook.presto.type.VarcharOperators; @@ -440,6 +444,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .aggregate(RealCorrelationAggregation.class) .aggregate(BitwiseOrAggregation.class) .aggregate(BitwiseAndAggregation.class) + .scalar(RepeatFunction.class) .scalars(SequenceFunction.class) .scalars(StringFunctions.class) .scalars(VarbinaryFunctions.class) @@ -523,6 +528,8 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .scalar(MapToMapCast.class) .scalars(EmptyMapConstructor.class) .scalar(TypeOfFunction.class) + .scalars(ListLiteralCast.class) + .scalars(GroupingOperationFunction.class) .function(ZIP_WITH_FUNCTION) .functions(ZIP_FUNCTIONS) .functions(ARRAY_JOIN, ARRAY_JOIN_WITH_NULL_REPLACEMENT) @@ -534,7 +541,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) .function(ARRAY_FLATTEN_FUNCTION) .function(ARRAY_CONCAT_FUNCTION) .functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_TO_JSON, JSON_TO_ARRAY) - .functions(new MapSubscriptOperator(featuresConfig.isLegacyMapSubscript())) + .functions(new MapSubscriptOperator(featuresConfig.isLegacyMapSubscript(), featuresConfig.isNewMapBlock())) .functions(MAP_CONSTRUCTOR, MAP_TO_JSON, JSON_TO_MAP) .functions(MAP_AGG, MULTIMAP_AGG, MAP_UNION) .functions(DECIMAL_TO_VARCHAR_CAST, DECIMAL_TO_INTEGER_CAST, DECIMAL_TO_BIGINT_CAST, DECIMAL_TO_DOUBLE_CAST, DECIMAL_TO_REAL_CAST, DECIMAL_TO_BOOLEAN_CAST, DECIMAL_TO_TINYINT_CAST, DECIMAL_TO_SMALLINT_CAST) @@ -584,6 +591,10 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key) } addFunctions(builder.getFunctions()); + + if (typeManager instanceof TypeRegistry) { + ((TypeRegistry) typeManager).setFunctionRegistry(this); + } } public final synchronized void addFunctions(List functions) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/HandleResolver.java b/presto-main/src/main/java/com/facebook/presto/metadata/HandleResolver.java index d69728cdbab92..4a45072216238 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/HandleResolver.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/HandleResolver.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.metadata; -import com.facebook.presto.connector.EmptySplitHandleResolver; import com.facebook.presto.connector.informationSchema.InformationSchemaHandleResolver; import com.facebook.presto.connector.system.SystemHandleResolver; import com.facebook.presto.spi.ColumnHandle; @@ -26,6 +25,7 @@ import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.split.EmptySplitHandleResolver; import javax.inject.Inject; diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java index 2888bbe2ff48e..427350a482983 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; @@ -74,6 +75,11 @@ public interface Metadata */ TableMetadata getTableMetadata(Session session, TableHandle tableHandle); + /** + * Return statistics for specified table for given filtering contraint. + */ + TableStatistics getTableStatistics(Session session, TableHandle tableHandle, Constraint constraint); + /** * Get the names that match the specified table prefix (never null). */ diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java index 1caf4a642b844..7e878d1ed9dbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -44,6 +44,7 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; @@ -339,6 +340,14 @@ public TableMetadata getTableMetadata(Session session, TableHandle tableHandle) return new TableMetadata(connectorId, tableMetadata); } + @Override + public TableStatistics getTableStatistics(Session session, TableHandle tableHandle, Constraint constraint) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + return metadata.getTableStatistics(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), constraint); + } + @Override public Map getColumnHandles(Session session, TableHandle tableHandle) { diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java index a2fe45843ab61..34d5bc9860850 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java @@ -20,17 +20,17 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.session.PropertyMetadata; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DoubleType; import com.facebook.presto.spi.type.IntegerType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.planner.ParameterRewriter; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import io.airlift.json.JsonCodec; diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java b/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java index 7718424b3b36d..cdff31ebf907d 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java @@ -737,7 +737,7 @@ public SolverReturnStatus update(BoundVariables.Builder bindings) if (!appendTypeRelationshipConstraintSolver(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), false)) { return SolverReturnStatus.UNSOLVABLE; } - if (!appendConstraintSolvers(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), false)) { + if (!appendConstraintSolvers(constraintsBuilder, formalLambdaReturnTypeSignature, new TypeSignatureProvider(actualReturnType.getTypeSignature()), allowCoercion)) { return SolverReturnStatus.UNSOLVABLE; } SolverReturnStatusMerger statusMerger = new SolverReturnStatusMerger(); @@ -834,7 +834,7 @@ public SolverReturnStatus update(BoundVariables.Builder bindings) TypeSignature boundSignature = applyBoundVariables(superTypeSignature, bindings.build()); - return satisfiesCoercion(allowCoercion, actualType, boundSignature) ? SolverReturnStatus.UNCHANGED_SATISFIED : SolverReturnStatus.UNSOLVABLE; + return satisfiesCoercion(allowCoercion, actualType, boundSignature) ? SolverReturnStatus.UNCHANGED_SATISFIED : SolverReturnStatus.UNCHANGED_NOT_SATISFIED; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java index 760300e38a9d9..fdf6d939b210f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java @@ -14,10 +14,9 @@ package com.facebook.presto.operator; import com.facebook.presto.spi.Page; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; -import java.util.Optional; -import java.util.function.Function; import static io.airlift.slice.SizeOf.sizeOf; import static java.util.Objects.requireNonNull; @@ -25,11 +24,14 @@ public final class ArrayPositionLinks implements PositionLinks { - public static class Builder implements PositionLinks.Builder + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ArrayPositionLinks.class).instanceSize(); + + public static class FactoryBuilder implements PositionLinks.FactoryBuilder { private final int[] positionLinks; + private int size; - private Builder(int size) + private FactoryBuilder(int size) { positionLinks = new int[size]; Arrays.fill(positionLinks, -1); @@ -38,15 +40,22 @@ private Builder(int size) @Override public int link(int left, int right) { + size++; positionLinks[left] = right; return left; } @Override - public Function, PositionLinks> build() + public Factory build() { return filterFunction -> new ArrayPositionLinks(positionLinks); } + + @Override + public int size() + { + return size; + } } private final int[] positionLinks; @@ -56,9 +65,9 @@ private ArrayPositionLinks(int[] positionLinks) this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); } - public static Builder builder(int size) + public static FactoryBuilder builder(int size) { - return new Builder(size); + return new FactoryBuilder(size); } @Override @@ -76,6 +85,6 @@ public int next(int position, int probePosition, Page allProbeChannelsPage) @Override public long getSizeInBytes() { - return sizeOf(positionLinks); + return INSTANCE_SIZE + sizeOf(positionLinks); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java b/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java index c15b8a35eef33..cb0e87a2f8336 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ArrayUnnester.java @@ -16,8 +16,8 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import javax.annotation.Nullable; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ChannelSet.java b/presto-main/src/main/java/com/facebook/presto/operator/ChannelSet.java index 85b8936663cad..8178a899c5d17 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ChannelSet.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ChannelSet.java @@ -53,6 +53,11 @@ public int size() return hash.getGroupCount(); } + public boolean isEmpty() + { + return size() == 0; + } + public boolean containsNull() { return containsNull; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/Driver.java b/presto-main/src/main/java/com/facebook/presto/operator/Driver.java index 431f19b0ce4e1..ac89987740a2e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/Driver.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/Driver.java @@ -19,7 +19,6 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.UpdatablePageSource; -import com.facebook.presto.split.EmptySplit; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.base.Throwables; import com.google.common.base.VerifyException; @@ -207,10 +206,6 @@ private void processNewSources() for (ScheduledSplit newSplit : newSplits) { Split split = newSplit.getSplit(); - if (split.getConnectorSplit() instanceof EmptySplit) { - continue; - } - Supplier> pageSource = sourceOperator.addSplit(split); deleteOperator.ifPresent(deleteOperator -> deleteOperator.setPageSource(pageSource)); } @@ -275,10 +270,10 @@ private ListenableFuture processInternal() // check if operator is blocked Operator current = operators.get(0); - ListenableFuture blocked = isBlocked(current); - if (!blocked.isDone()) { - current.getOperatorContext().recordBlocked(blocked); - return blocked; + Optional> blocked = getBlockedFuture(current); + if (blocked.isPresent()) { + current.getOperatorContext().recordBlocked(blocked.get()); + return blocked.get(); } // there is only one operator so just finish it @@ -293,16 +288,13 @@ private ListenableFuture processInternal() Operator current = operators.get(i); Operator next = operators.get(i + 1); - // skip blocked operators - if (!isBlocked(current).isDone()) { - continue; - } - if (!isBlocked(next).isDone()) { + // skip blocked operator + if (getBlockedFuture(current).isPresent()) { continue; } - // if the current operator is not finished and next operator needs input... - if (!current.isFinished() && next.needsInput()) { + // if the current operator is not finished and next operator isn't blocked and needs input... + if (!current.isFinished() && !getBlockedFuture(next).isPresent() && next.needsInput()) { // get an output page from current operator current.getOperatorContext().startIntervalTimer(); Page page = current.getOutput(); @@ -335,10 +327,10 @@ private ListenableFuture processInternal() List blockedOperators = new ArrayList<>(); List> blockedFutures = new ArrayList<>(); for (Operator operator : operators) { - ListenableFuture blocked = isBlocked(operator); - if (!blocked.isDone()) { + Optional> blocked = getBlockedFuture(operator); + if (blocked.isPresent()) { blockedOperators.add(operator); - blockedFutures.add(blocked); + blockedFutures.add(blocked.get()); } } @@ -462,13 +454,17 @@ private void destroyIfNecessary() } } - private static ListenableFuture isBlocked(Operator operator) + private static Optional> getBlockedFuture(Operator operator) { ListenableFuture blocked = operator.isBlocked(); - if (blocked.isDone()) { - blocked = operator.getOperatorContext().isWaitingForMemory(); + if (!blocked.isDone()) { + return Optional.of(blocked); + } + blocked = operator.getOperatorContext().isWaitingForMemory(); + if (!blocked.isDone()) { + return Optional.of(blocked); } - return blocked; + return Optional.empty(); } private static Throwable addSuppressedException(Throwable inFlightException, Throwable newException, String message, Object... args) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java index 0e01b8b88f1b0..f53679de0a47d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ExplainAnalyzeOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryPerformanceFetcher; import com.facebook.presto.execution.StageId; @@ -44,14 +45,16 @@ public static class ExplainAnalyzeOperatorFactory private final PlanNodeId planNodeId; private final QueryPerformanceFetcher queryPerformanceFetcher; private final Metadata metadata; + private final CostCalculator costCalculator; private boolean closed; - public ExplainAnalyzeOperatorFactory(int operatorId, PlanNodeId planNodeId, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata) + public ExplainAnalyzeOperatorFactory(int operatorId, PlanNodeId planNodeId, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.metadata = requireNonNull(metadata, "metadata is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); } @Override @@ -65,7 +68,7 @@ public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, ExplainAnalyzeOperator.class.getSimpleName()); - return new ExplainAnalyzeOperator(operatorContext, queryPerformanceFetcher, metadata); + return new ExplainAnalyzeOperator(operatorContext, queryPerformanceFetcher, metadata, costCalculator); } @Override @@ -77,21 +80,23 @@ public void close() @Override public OperatorFactory duplicate() { - return new ExplainAnalyzeOperatorFactory(operatorId, planNodeId, queryPerformanceFetcher, metadata); + return new ExplainAnalyzeOperatorFactory(operatorId, planNodeId, queryPerformanceFetcher, metadata, costCalculator); } } private final OperatorContext operatorContext; private final QueryPerformanceFetcher queryPerformanceFetcher; private final Metadata metadata; + private final CostCalculator costCalculator; private boolean finishing; private boolean outputConsumed; - public ExplainAnalyzeOperator(OperatorContext operatorContext, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata) + public ExplainAnalyzeOperator(OperatorContext operatorContext, QueryPerformanceFetcher queryPerformanceFetcher, Metadata metadata, CostCalculator costCalculator) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.metadata = requireNonNull(metadata, "metadata is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); } @Override @@ -146,7 +151,7 @@ public Page getOutput() return null; } - String plan = textDistributedPlan(queryInfo.getOutputStage().get(), metadata, operatorContext.getSession()); + String plan = textDistributedPlan(queryInfo.getOutputStage().get(), metadata, costCalculator, operatorContext.getSession()); BlockBuilder builder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 1); VARCHAR.writeString(builder, plan); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java b/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java index 6e0da9db77ed2..be2c6588e0b0d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/GroupByIdBlock.java @@ -20,6 +20,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.base.MoreObjects.toStringHelper; @@ -57,7 +58,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { return block.getRegionSizeInBytes(positionOffset, length); } @@ -171,17 +172,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return block.getSizeInBytes(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + block.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(block, block.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockEncoding getEncoding() { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashSemiJoinOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashSemiJoinOperator.java index 3a023d8654c41..2a11cf9c9a946 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashSemiJoinOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashSemiJoinOperator.java @@ -15,7 +15,6 @@ import com.facebook.presto.operator.SetBuilderOperator.SetSupplier; import com.facebook.presto.spi.Page; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.Type; @@ -25,7 +24,6 @@ import java.util.List; -import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -176,9 +174,12 @@ public void addInput(Page page) // update hashing strategy to use probe cursor for (int position = 0; position < page.getPositionCount(); position++) { if (probeJoinPage.getBlock(0).isNull(position)) { - throw new PrestoException( - NOT_SUPPORTED, - "NULL values are not allowed on the probe side of SemiJoin operator. See the query plan for details."); + if (channelSet.isEmpty()) { + BOOLEAN.writeBoolean(blockBuilder, false); + } + else { + blockBuilder.appendNull(); + } } else { boolean contains = channelSet.contains(position, probeJoinPage); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java index 0537bfb6beefd..1e9ebea31c8c0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinHash.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; +import org.openjdk.jol.info.ClassLayout; import javax.annotation.Nullable; @@ -27,6 +28,7 @@ public final class JoinHash implements LookupSource { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(JoinHash.class).instanceSize(); private final PagesHash pagesHash; // we unwrap Optional to actual verifier or null in constructor for performance reasons @@ -34,13 +36,16 @@ public final class JoinHash @Nullable private final JoinFilterFunction filterFunction; + // we unwrap Optional to actual position links or null in constructor for performance reasons + // we do quick check for `positionLinks == null` to avoid calls to positionLinks + @Nullable private final PositionLinks positionLinks; - public JoinHash(PagesHash pagesHash, Optional filterFunction, PositionLinks positionLinks) + public JoinHash(PagesHash pagesHash, Optional filterFunction, Optional positionLinks) { this.pagesHash = requireNonNull(pagesHash, "pagesHash is null"); this.filterFunction = requireNonNull(filterFunction, "filterFunction can not be null").orElse(null); - this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); + this.positionLinks = requireNonNull(positionLinks, "positionLinks is null").orElse(null); } @Override @@ -58,20 +63,20 @@ public int getJoinPositionCount() @Override public long getInMemorySizeInBytes() { - return pagesHash.getInMemorySizeInBytes() + positionLinks.getSizeInBytes(); + return INSTANCE_SIZE + pagesHash.getInMemorySizeInBytes() + (positionLinks == null ? 0 : positionLinks.getSizeInBytes()); } @Override public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage) { - int addressIndex = pagesHash.getAddressIndex(position, hashChannelsPage, allChannelsPage); + int addressIndex = pagesHash.getAddressIndex(position, hashChannelsPage); return startJoinPosition(addressIndex, position, allChannelsPage); } @Override public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash) { - int addressIndex = pagesHash.getAddressIndex(position, hashChannelsPage, allChannelsPage, rawHash); + int addressIndex = pagesHash.getAddressIndex(position, hashChannelsPage, rawHash); return startJoinPosition(addressIndex, position, allChannelsPage); } @@ -80,12 +85,18 @@ private long startJoinPosition(int currentJoinPosition, int probePosition, Page if (currentJoinPosition == -1) { return -1; } + if (positionLinks == null) { + return currentJoinPosition; + } return positionLinks.start(currentJoinPosition, probePosition, allProbeChannelsPage); } @Override public final long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) { + if (positionLinks == null) { + return -1; + } return positionLinks.next(toIntExact(currentJoinPosition), probePosition, allProbeChannelsPage); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java index c8b322a610797..286da876d8af9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java @@ -16,16 +16,12 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.block.Block; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; -import it.unimi.dsi.fastutil.ints.IntComparator; import it.unimi.dsi.fastutil.longs.LongArrayList; import java.util.List; import java.util.Optional; -import java.util.function.Function; import static com.facebook.presto.SystemSessionProperties.isFastInequalityJoin; -import static com.facebook.presto.operator.SyntheticAddress.decodePosition; -import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static java.util.Objects.requireNonNull; public class JoinHashSupplier @@ -35,7 +31,7 @@ public class JoinHashSupplier private final PagesHash pagesHash; private final LongArrayList addresses; private final List> channels; - private final Function, PositionLinks> positionLinks; + private final Optional positionLinks; private final Optional filterFunctionFactory; public JoinHashSupplier( @@ -51,20 +47,21 @@ public JoinHashSupplier( this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); requireNonNull(pagesHashStrategy, "pagesHashStrategy is null"); - PositionLinks.Builder positionLinksBuilder; + PositionLinks.FactoryBuilder positionLinksFactoryBuilder; if (filterFunctionFactory.isPresent() && filterFunctionFactory.get().getSortChannel().isPresent() && isFastInequalityJoin(session)) { - positionLinksBuilder = SortedPositionLinks.builder( + positionLinksFactoryBuilder = SortedPositionLinks.builder( addresses.size(), - new PositionComparator(pagesHashStrategy, addresses)); + pagesHashStrategy, + addresses); } else { - positionLinksBuilder = ArrayPositionLinks.builder(addresses.size()); + positionLinksFactoryBuilder = ArrayPositionLinks.builder(addresses.size()); } - this.pagesHash = new PagesHash(addresses, pagesHashStrategy, positionLinksBuilder); - this.positionLinks = positionLinksBuilder.build(); + this.pagesHash = new PagesHash(addresses, pagesHashStrategy, positionLinksFactoryBuilder); + this.positionLinks = positionLinksFactoryBuilder.isEmpty() ? Optional.empty() : Optional.of(positionLinksFactoryBuilder.build()); } @Override @@ -89,39 +86,6 @@ public JoinHash get() return new JoinHash( pagesHash, filterFunction, - positionLinks.apply(filterFunction)); - } - - public static class PositionComparator - implements IntComparator - { - private final PagesHashStrategy pagesHashStrategy; - private final LongArrayList addresses; - - public PositionComparator(PagesHashStrategy pagesHashStrategy, LongArrayList addresses) - { - this.pagesHashStrategy = pagesHashStrategy; - this.addresses = addresses; - } - - @Override - public int compare(int leftPosition, int rightPosition) - { - long leftPageAddress = addresses.getLong(leftPosition); - int leftBlockIndex = decodeSliceIndex(leftPageAddress); - int leftBlockPosition = decodePosition(leftPageAddress); - - long rightPageAddress = addresses.getLong(rightPosition); - int rightBlockIndex = decodeSliceIndex(rightPageAddress); - int rightBlockPosition = decodePosition(rightPageAddress); - - return pagesHashStrategy.compare(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition); - } - - @Override - public int compare(Integer leftPosition, Integer rightPosition) - { - return compare(leftPosition.intValue(), rightPosition.intValue()); - } + positionLinks.map(links -> links.create(filterFunction))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java new file mode 100644 index 0000000000000..c82fb0e3da94d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinOperatorInfo.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.operator.LookupJoinOperators.JoinType; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static com.facebook.presto.operator.JoinStatisticsCounter.HISTOGRAM_BUCKETS; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +public class JoinOperatorInfo + implements Mergeable, OperatorInfo +{ + private final JoinType joinType; + private final long[] logHistogramProbes; + private final long[] logHistogramOutput; + + public static JoinOperatorInfo createJoinOperatorInfo(JoinType joinType, long[] logHistogramCounters) + { + long[] logHistogramProbes = new long[HISTOGRAM_BUCKETS]; + long[] logHistogramOutput = new long[HISTOGRAM_BUCKETS]; + for (int i = 0; i < HISTOGRAM_BUCKETS; i++) { + logHistogramProbes[i] = logHistogramCounters[2 * i]; + logHistogramOutput[i] = logHistogramCounters[2 * i + 1]; + } + return new JoinOperatorInfo(joinType, logHistogramProbes, logHistogramOutput); + } + + @JsonCreator + public JoinOperatorInfo( + @JsonProperty("joinType") JoinType joinType, + @JsonProperty("logHistogramProbes") long[] logHistogramProbes, + @JsonProperty("logHistogramOutput") long[] logHistogramOutput) + { + checkArgument(logHistogramProbes.length == HISTOGRAM_BUCKETS); + checkArgument(logHistogramOutput.length == HISTOGRAM_BUCKETS); + this.joinType = joinType; + this.logHistogramProbes = logHistogramProbes; + this.logHistogramOutput = logHistogramOutput; + } + + @JsonProperty + public JoinType getJoinType() + { + return joinType; + } + + @JsonProperty + public long[] getLogHistogramProbes() + { + return logHistogramProbes; + } + + @JsonProperty + public long[] getLogHistogramOutput() + { + return logHistogramOutput; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("joinType", joinType) + .add("logHistogramProbes", logHistogramProbes) + .add("logHistogramOutput", logHistogramOutput) + .toString(); + } + + @Override + public JoinOperatorInfo mergeWith(JoinOperatorInfo other) + { + checkState(this.joinType.equals(other.joinType), "different join types"); + long[] logHistogramProbes = new long[HISTOGRAM_BUCKETS]; + long[] logHistogramOutput = new long[HISTOGRAM_BUCKETS]; + for (int i = 0; i < HISTOGRAM_BUCKETS; i++) { + logHistogramProbes[i] = this.logHistogramProbes[i] + other.logHistogramProbes[i]; + logHistogramOutput[i] = this.logHistogramOutput[i] + other.logHistogramOutput[i]; + } + return new JoinOperatorInfo(this.joinType, logHistogramProbes, logHistogramOutput); + } + + @Override + public boolean isFinal() + { + return true; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinStatisticsCounter.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinStatisticsCounter.java new file mode 100644 index 0000000000000..f3ae395514485 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinStatisticsCounter.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.operator.LookupJoinOperators.JoinType; + +import java.util.function.Supplier; + +import static com.facebook.presto.operator.JoinOperatorInfo.createJoinOperatorInfo; +import static java.util.Objects.requireNonNull; + +public class JoinStatisticsCounter + implements Supplier +{ + public static final int HISTOGRAM_BUCKETS = 8; + + private static final int INDIVIDUAL_BUCKETS = 4; + + private final JoinType joinType; + // Logarithmic histogram. Regular histogram (or digest) is too expensive, because of memory manipulations. Also, we don't need their guarantees of precision. + // To make it maximally fast by reducing indirections (it will fit in cache L1 anyways) counters are packed in one array. + // Layout (here "bucket" is histogram bucket): + // [2*bucket] count probe positions that produced "bucket" rows on source side, + // [2*bucket + 1] total count of rows that were produces by probe rows in this bucket. + private final long[] logHistogramCounters = new long[HISTOGRAM_BUCKETS * 2]; + + public JoinStatisticsCounter(JoinType joinType) + { + this.joinType = requireNonNull(joinType, "joinType is null"); + } + + public void recordProbe(int numSourcePositions) + { + int bucket; + if (numSourcePositions <= INDIVIDUAL_BUCKETS) { + bucket = numSourcePositions; + } + else if (numSourcePositions <= 10) { + bucket = INDIVIDUAL_BUCKETS + 1; + } + else if (numSourcePositions <= 100) { + bucket = INDIVIDUAL_BUCKETS + 2; + } + else { + bucket = INDIVIDUAL_BUCKETS + 3; + } + logHistogramCounters[2 * bucket]++; + logHistogramCounters[2 * bucket + 1] += numSourcePositions; + } + + @Override + public JoinOperatorInfo get() + { + return createJoinOperatorInfo(joinType, logHistogramCounters); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java index 9976fad52342e..8d484ca1db380 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java @@ -40,6 +40,8 @@ public class LookupJoinOperator private final JoinProbeFactory joinProbeFactory; private final Runnable onClose; + private final JoinStatisticsCounter statisticsCounter; + private final PageBuilder pageBuilder; private final boolean probeOnOuterSide; @@ -50,6 +52,7 @@ public class LookupJoinOperator private boolean closed; private boolean finishing; private long joinPosition = -1; + private int joinSourcePositions = 0; private boolean currentProbePositionProducedRow; @@ -72,6 +75,9 @@ public LookupJoinOperator( this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null"); this.onClose = requireNonNull(onClose, "onClose is null"); + this.statisticsCounter = new JoinStatisticsCounter(joinType); + operatorContext.setInfoSupplier(this.statisticsCounter); + this.pageBuilder = new PageBuilder(types); } @@ -165,6 +171,8 @@ public Page getOutput() if (!advanceProbePosition()) { break; } + statisticsCounter.recordProbe(joinSourcePositions); + joinSourcePositions = 0; } } @@ -215,6 +223,7 @@ private boolean joinCurrentPosition(Counter lookupPositionsConsidered) probe.appendTo(pageBuilder); // write build columns lookupSource.appendTo(joinPosition, pageBuilder, probe.getOutputChannelCount()); + joinSourcePositions++; } // get next position on lookup side for this probe row diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java index 0f44d8a9c5831..9bfe19e9952a4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupJoinOperatorFactory.java @@ -19,17 +19,16 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.SettableFuture; +import com.google.common.util.concurrent.ListenableFuture; import java.util.List; import java.util.Optional; -import java.util.function.Consumer; import static com.facebook.presto.operator.LookupJoinOperators.JoinType.INNER; import static com.facebook.presto.operator.LookupJoinOperators.JoinType.PROBE_OUTER; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.transform; +import static com.google.common.util.concurrent.Futures.transformAsync; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static java.util.Objects.requireNonNull; @@ -46,7 +45,8 @@ public class LookupJoinOperatorFactory private final LookupSourceFactory lookupSourceFactory; private final JoinProbeFactory joinProbeFactory; private final Optional outerOperatorFactory; - private final ReferenceCount referenceCount; + private final ReferenceCount probeReferenceCount; + private final ReferenceCount lookupSourceFactoryUsersCount; private boolean closed; public LookupJoinOperatorFactory(int operatorId, @@ -67,29 +67,29 @@ public LookupJoinOperatorFactory(int operatorId, this.joinType = requireNonNull(joinType, "joinType is null"); this.joinProbeFactory = requireNonNull(joinProbeFactory, "joinProbeFactory is null"); - this.referenceCount = new ReferenceCount(); + probeReferenceCount = new ReferenceCount(); + lookupSourceFactoryUsersCount = new ReferenceCount(); + + // when all probe and build-outer operators finish, destroy the lookup source (freeing the memory) + lookupSourceFactoryUsersCount.getFreeFuture().addListener(lookupSourceFactory::destroy, directExecutor()); + + // Whole probe side is counted as 1 in lookupSourceFactoryUsersCount + probeReferenceCount.getFreeFuture().addListener(lookupSourceFactoryUsersCount::release, directExecutor()); if (joinType == INNER || joinType == PROBE_OUTER) { - // when all join operators finish, destroy the lookup source (freeing the memory) - this.referenceCount.getFreeFuture().addListener(lookupSourceFactory::destroy, directExecutor()); this.outerOperatorFactory = Optional.empty(); } else { - // when all join operators finish, set the outer position future to start the outer operator - SettableFuture outerPositionsFuture = SettableFuture.create(); - this.referenceCount.getFreeFuture().addListener(() -> { - // lookup source may not be finished yet, so add a listener - Futures.addCallback( - lookupSourceFactory.createLookupSource(), - new OnSuccessFutureCallback<>(lookupSource -> outerPositionsFuture.set(lookupSource.getOuterPositionIterator()))); - }, directExecutor()); - - // when output operator finishes, destroy the lookup source - Runnable onOperatorClose = () -> { - // lookup source may not be finished yet, so add a listener, to free the memory - lookupSourceFactory.createLookupSource().addListener(lookupSourceFactory::destroy, directExecutor()); - }; - this.outerOperatorFactory = Optional.of(new LookupOuterOperatorFactory(operatorId, planNodeId, outerPositionsFuture, probeOutputTypes, buildOutputTypes, onOperatorClose)); + // when all join operators finish (and lookup source is ready), set the outer position future to start the outer operator + ListenableFuture lookupSourceAfterProbeFinished = transformAsync(probeReferenceCount.getFreeFuture(), ignored -> lookupSourceFactory.createLookupSource()); + ListenableFuture outerPositionsFuture = transform(lookupSourceAfterProbeFinished, lookupSource -> { + try (LookupSource ignore = lookupSource) { + return lookupSource.getOuterPositionIterator(); + } + }); + + lookupSourceFactoryUsersCount.retain(); + this.outerOperatorFactory = Optional.of(new LookupOuterOperatorFactory(operatorId, planNodeId, outerPositionsFuture, probeOutputTypes, buildOutputTypes, lookupSourceFactoryUsersCount)); } } @@ -105,10 +105,11 @@ private LookupJoinOperatorFactory(LookupJoinOperatorFactory other) joinType = other.joinType; lookupSourceFactory = other.lookupSourceFactory; joinProbeFactory = other.joinProbeFactory; - referenceCount = other.referenceCount; + probeReferenceCount = other.probeReferenceCount; + lookupSourceFactoryUsersCount = other.lookupSourceFactoryUsersCount; outerOperatorFactory = other.outerOperatorFactory; - referenceCount.retain(); + probeReferenceCount.retain(); } public int getOperatorId() @@ -133,14 +134,14 @@ public Operator createOperator(DriverContext driverContext) lookupSourceFactory.setTaskContext(driverContext.getPipelineContext().getTaskContext()); - referenceCount.retain(); + probeReferenceCount.retain(); return new LookupJoinOperator( operatorContext, getTypes(), joinType, lookupSourceFactory.createLookupSource(), joinProbeFactory, - referenceCount::release); + probeReferenceCount::release); } @Override @@ -150,7 +151,7 @@ public void close() return; } closed = true; - referenceCount.release(); + probeReferenceCount.release(); } @Override @@ -164,27 +165,4 @@ public Optional createOuterOperatorFactory() { return outerOperatorFactory; } - - // We use a public class to avoid access problems with the isolated class loaders - public static class OnSuccessFutureCallback - implements FutureCallback - { - private final Consumer onSuccess; - - public OnSuccessFutureCallback(Consumer onSuccess) - { - this.onSuccess = onSuccess; - } - - @Override - public void onSuccess(T result) - { - onSuccess.accept(result); - } - - @Override - public void onFailure(Throwable t) - { - } - } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java index 0d9317a377bfa..78e675223fc4a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/LookupOuterOperator.java @@ -44,7 +44,7 @@ private enum State private final List types; private final List probeOutputTypes; private final List buildOutputTypes; - private final Runnable onOperatorClose; + private final ReferenceCount referenceCount; private State state = State.NOT_CREATED; public LookupOuterOperatorFactory( @@ -53,14 +53,14 @@ public LookupOuterOperatorFactory( ListenableFuture outerPositionsFuture, List probeOutputTypes, List buildOutputTypes, - Runnable onOperatorClose) + ReferenceCount referenceCount) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.outerPositionsFuture = requireNonNull(outerPositionsFuture, "outerPositionsFuture is null"); this.probeOutputTypes = ImmutableList.copyOf(requireNonNull(probeOutputTypes, "probeOutputTypes is null")); this.buildOutputTypes = ImmutableList.copyOf(requireNonNull(buildOutputTypes, "buildOutputTypes is null")); - this.onOperatorClose = requireNonNull(onOperatorClose, "referenceCount is null"); + this.referenceCount = requireNonNull(referenceCount, "referenceCount is null"); this.types = ImmutableList.builder() .addAll(probeOutputTypes) @@ -86,7 +86,8 @@ public Operator createOperator(DriverContext driverContext) state = State.CREATED; OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, LookupOuterOperator.class.getSimpleName()); - return new LookupOuterOperator(operatorContext, outerPositionsFuture, probeOutputTypes, buildOutputTypes, onOperatorClose); + referenceCount.retain(); + return new LookupOuterOperator(operatorContext, outerPositionsFuture, probeOutputTypes, buildOutputTypes, referenceCount::release); } @Override @@ -96,7 +97,7 @@ public void close() return; } state = State.CLOSED; - onOperatorClose.run(); + referenceCount.release(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java b/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java index aa9ab6776bb92..1342f26ee1563 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/MapUnnester.java @@ -16,8 +16,8 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.MapType; import javax.annotation.Nullable; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/Operator.java b/presto-main/src/main/java/com/facebook/presto/operator/Operator.java index 3e92518ad7268..2142eb2434b8d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/Operator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/Operator.java @@ -32,19 +32,6 @@ public interface Operator */ List getTypes(); - /** - * Notifies the operator that no more pages will be added and the - * operator should finish processing and flush results. This method - * will not be called if the Task is already failed or canceled. - */ - void finish(); - - /** - * Is this operator completely finished processing and no more - * output pages will be produced. - */ - boolean isFinished(); - /** * Returns a future that will be completed when the operator becomes * unblocked. If the operator is not blocked, this method should return @@ -72,6 +59,19 @@ default ListenableFuture isBlocked() */ Page getOutput(); + /** + * Notifies the operator that no more pages will be added and the + * operator should finish processing and flush results. This method + * will not be called if the Task is already failed or canceled. + */ + void finish(); + + /** + * Is this operator completely finished processing and no more + * output pages will be produced. + */ + boolean isFinished(); + /** * This method will always be called before releasing the Operator reference. */ diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java index 7da798861aa69..20dd585272474 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java @@ -40,6 +40,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static io.airlift.units.DataSize.succinctBytes; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -359,6 +360,12 @@ public CounterStat getOutputPositions() return outputPositions; } + @Override + public String toString() + { + return format("%s-%s", operatorType, planNodeId); + } + public OperatorStats getOperatorStats() { Supplier infoSupplier = this.infoSupplier.get(); @@ -502,7 +509,7 @@ public String toString() @ThreadSafe private class OperatorSpillContext - implements SpillContext + implements SpillContext { private final DriverContext driverContext; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java index ee8823723a2be..11e476e8d0472 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorInfo.java @@ -28,7 +28,8 @@ @JsonSubTypes.Type(value = TableFinishInfo.class, name = "tableFinish"), @JsonSubTypes.Type(value = SplitOperatorInfo.class, name = "splitOperator"), @JsonSubTypes.Type(value = HashCollisionsInfo.class, name = "hashCollisionsInfo"), - @JsonSubTypes.Type(value = PartitionedOutputInfo.class, name = "partitionedOutput") + @JsonSubTypes.Type(value = PartitionedOutputInfo.class, name = "partitionedOutput"), + @JsonSubTypes.Type(value = JoinOperatorInfo.class, name = "joinOperatorInfo") }) public interface OperatorInfo { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OrderByOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/OrderByOperator.java index e35c54f5ca21d..a47421510b2f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OrderByOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OrderByOperator.java @@ -184,6 +184,11 @@ public void addInput(Page page) requireNonNull(page, "page is null"); pageIndex.addPage(page); + + if (!operatorContext.trySetMemoryReservation(pageIndex.getEstimatedSize().toBytes())) { + pageIndex.compact(); + } + operatorContext.setMemoryReservation(pageIndex.getEstimatedSize().toBytes()); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java index 3eb9afe20fff0..cc88e927b29fa 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesHash.java @@ -18,6 +18,7 @@ import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.HashCommon; import it.unimi.dsi.fastutil.longs.LongArrayList; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; @@ -32,6 +33,7 @@ // This implementation assumes arrays used in the hash are always a power of 2 public final class PagesHash { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(PagesHash.class).instanceSize(); private static final DataSize CACHE_SIZE = new DataSize(128, KILOBYTE); private final LongArrayList addresses; private final PagesHashStrategy pagesHashStrategy; @@ -51,7 +53,7 @@ public final class PagesHash public PagesHash( LongArrayList addresses, PagesHashStrategy pagesHashStrategy, - PositionLinks.Builder positionLinks) + PositionLinks.FactoryBuilder positionLinks) { this.addresses = requireNonNull(addresses, "addresses is null"); this.pagesHashStrategy = requireNonNull(pagesHashStrategy, "pagesHashStrategy is null"); @@ -134,7 +136,7 @@ public int getPositionCount() public long getInMemorySizeInBytes() { - return size; + return INSTANCE_SIZE + size; } public long getHashCollisions() @@ -147,12 +149,12 @@ public double getExpectedHashCollisions() return expectedHashCollisions; } - public int getAddressIndex(int position, Page hashChannelsPage, Page allChannelsPage) + public int getAddressIndex(int position, Page hashChannelsPage) { - return getAddressIndex(position, hashChannelsPage, allChannelsPage, pagesHashStrategy.hashRow(position, hashChannelsPage)); + return getAddressIndex(position, hashChannelsPage, pagesHashStrategy.hashRow(position, hashChannelsPage)); } - public int getAddressIndex(int rightPosition, Page hashChannelsPage, Page allChannelsPage, long rawHash) + public int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHash) { int pos = getHashPosition(rawHash, mask); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java index 714db377b7105..90bdc068a55b9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesHashStrategy.java @@ -93,5 +93,13 @@ public interface PagesHashStrategy */ boolean isPositionNull(int blockIndex, int blockPosition); - int compare(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition); + /** + * Compares sort channel (if applicable) values at the specified positions. + */ + int compareSortChannelPositions(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition); + + /** + * Checks if sort channel is null at the specified position + */ + boolean isSortChannelPositionNull(int blockIndex, int blockPosition); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java index 2ace70165d2eb..beaa7ce7e7560 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PartitionedOutputOperator.java @@ -56,6 +56,7 @@ public static class PartitionedOutputFactory private final List partitionChannels; private final List> partitionConstants; private final OutputBuffer outputBuffer; + private final boolean replicatesAnyRow; private final OptionalInt nullChannel; private final DataSize maxMemory; @@ -63,6 +64,7 @@ public PartitionedOutputFactory( PartitionFunction partitionFunction, List partitionChannels, List> partitionConstants, + boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, DataSize maxMemory) @@ -70,6 +72,7 @@ public PartitionedOutputFactory( this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null"); + this.replicatesAnyRow = replicatesAnyRow; this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.maxMemory = requireNonNull(maxMemory, "maxMemory is null"); @@ -91,6 +94,7 @@ public OperatorFactory createOutputOperator( partitionFunction, partitionChannels, partitionConstants, + replicatesAnyRow, nullChannel, outputBuffer, serdeFactory, @@ -108,6 +112,7 @@ public static class PartitionedOutputOperatorFactory private final PartitionFunction partitionFunction; private final List partitionChannels; private final List> partitionConstants; + private final boolean replicatesAnyRow; private final OptionalInt nullChannel; private final OutputBuffer outputBuffer; private final PagesSerdeFactory serdeFactory; @@ -121,6 +126,7 @@ public PartitionedOutputOperatorFactory( PartitionFunction partitionFunction, List partitionChannels, List> partitionConstants, + boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, PagesSerdeFactory serdeFactory, @@ -133,6 +139,7 @@ public PartitionedOutputOperatorFactory( this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null"); + this.replicatesAnyRow = replicatesAnyRow; this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.serdeFactory = requireNonNull(serdeFactory, "serdeFactory is null"); @@ -156,6 +163,7 @@ public Operator createOperator(DriverContext driverContext) partitionFunction, partitionChannels, partitionConstants, + replicatesAnyRow, nullChannel, outputBuffer, serdeFactory, @@ -178,6 +186,7 @@ public OperatorFactory duplicate() partitionFunction, partitionChannels, partitionConstants, + replicatesAnyRow, nullChannel, outputBuffer, serdeFactory, @@ -198,6 +207,7 @@ public PartitionedOutputOperator( PartitionFunction partitionFunction, List partitionChannels, List> partitionConstants, + boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, PagesSerdeFactory serdeFactory, @@ -209,6 +219,7 @@ public PartitionedOutputOperator( partitionFunction, partitionChannels, partitionConstants, + replicatesAnyRow, nullChannel, outputBuffer, serdeFactory, @@ -298,14 +309,17 @@ private static class PagePartitioner private final List> partitionConstants; private final PagesSerde serde; private final List pageBuilders; + private final boolean replicatesAnyRow; private final OptionalInt nullChannel; // when present, send the position to every partition if this channel is null. private final AtomicLong rowsAdded = new AtomicLong(); private final AtomicLong pagesAdded = new AtomicLong(); + private boolean hasAnyRowBeenReplicated; public PagePartitioner( PartitionFunction partitionFunction, List partitionChannels, List> partitionConstants, + boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, PagesSerdeFactory serdeFactory, @@ -317,6 +331,7 @@ public PagePartitioner( this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null").stream() .map(constant -> constant.map(NullableValue::asBlock)) .collect(toImmutableList()); + this.replicatesAnyRow = replicatesAnyRow; this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null"); @@ -351,26 +366,19 @@ public ListenableFuture partitionPage(Page page) Page partitionFunctionArgs = getPartitionFunctionArguments(page); for (int position = 0; position < page.getPositionCount(); position++) { - if (nullChannel.isPresent() && page.getBlock(nullChannel.getAsInt()).isNull(position)) { + boolean shouldReplicate = (replicatesAnyRow && !hasAnyRowBeenReplicated) || + nullChannel.isPresent() && page.getBlock(nullChannel.getAsInt()).isNull(position); + if (shouldReplicate) { for (PageBuilder pageBuilder : pageBuilders) { - pageBuilder.declarePosition(); - - for (int channel = 0; channel < sourceTypes.size(); channel++) { - Type type = sourceTypes.get(channel); - type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); - } + appendRow(pageBuilder, page, position); } + hasAnyRowBeenReplicated = true; } else { int partition = partitionFunction.getPartition(partitionFunctionArgs, position); PageBuilder pageBuilder = pageBuilders.get(partition); - pageBuilder.declarePosition(); - - for (int channel = 0; channel < sourceTypes.size(); channel++) { - Type type = sourceTypes.get(channel); - type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); - } + appendRow(pageBuilder, page, position); } } return flush(false); @@ -391,6 +399,16 @@ private Page getPartitionFunctionArguments(Page page) return new Page(page.getPositionCount(), blocks); } + private void appendRow(PageBuilder pageBuilder, Page page, int position) + { + pageBuilder.declarePosition(); + + for (int channel = 0; channel < sourceTypes.size(); channel++) { + Type type = sourceTypes.get(channel); + type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); + } + } + public ListenableFuture flush(boolean force) { // add all full pages to output buffer diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java b/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java index 1e0a1836a7666..30f26ecea503e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PipelineContext.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map.Entry; +import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -338,6 +339,7 @@ public PipelineStats getPipelineStats() int queuedPartitionedDrivers = 0; int runningDrivers = 0; int runningPartitionedDrivers = 0; + int blockedDrivers = 0; int completedDrivers = this.completedDrivers.get(); Distribution queuedTime = new Distribution(this.queuedTime); @@ -370,6 +372,9 @@ public PipelineStats getPipelineStats() queuedPartitionedDrivers++; } } + else if (driverStats.isFullyBlocked()) { + blockedDrivers++; + } else { runningDrivers++; if (driverContext.isPartitioned()) { @@ -413,13 +418,15 @@ public PipelineStats getPipelineStats() operatorSummaries.put(entry.getKey(), current); } - ImmutableSet blockedReasons = drivers.stream() + Set runningDriverStats = drivers.stream() .filter(driver -> driver.getEndTime() == null && driver.getStartTime() != null) + .collect(toImmutableSet()); + ImmutableSet blockedReasons = runningDriverStats.stream() .flatMap(driver -> driver.getBlockedReasons().stream()) .collect(toImmutableSet()); - boolean fullyBlocked = drivers.stream() - .filter(driver -> driver.getEndTime() == null && driver.getStartTime() != null) - .allMatch(DriverStats::isFullyBlocked); + + boolean fullyBlocked = !runningDriverStats.isEmpty() && runningDriverStats.stream().allMatch(DriverStats::isFullyBlocked); + return new PipelineStats( pipelineId, @@ -435,6 +442,7 @@ public PipelineStats getPipelineStats() queuedPartitionedDrivers, runningDrivers, runningPartitionedDrivers, + blockedDrivers, completedDrivers, succinctBytes(memoryReservation.get()), @@ -447,7 +455,7 @@ public PipelineStats getPipelineStats() new Duration(totalCpuTime, NANOSECONDS).convertToMostSuccinctTimeUnit(), new Duration(totalUserTime, NANOSECONDS).convertToMostSuccinctTimeUnit(), new Duration(totalBlockedTime, NANOSECONDS).convertToMostSuccinctTimeUnit(), - fullyBlocked && (runningDrivers > 0 || runningPartitionedDrivers > 0), + fullyBlocked, blockedReasons, succinctBytes(rawInputDataSize), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PipelineStats.java b/presto-main/src/main/java/com/facebook/presto/operator/PipelineStats.java index d329d4da3797b..edf966f9dc179 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PipelineStats.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PipelineStats.java @@ -49,6 +49,7 @@ public class PipelineStats private final int queuedPartitionedDrivers; private final int runningDrivers; private final int runningPartitionedDrivers; + private final int blockedDrivers; private final int completedDrivers; private final DataSize memoryReservation; @@ -92,6 +93,7 @@ public PipelineStats( @JsonProperty("queuedPartitionedDrivers") int queuedPartitionedDrivers, @JsonProperty("runningDrivers") int runningDrivers, @JsonProperty("runningPartitionedDrivers") int runningPartitionedDrivers, + @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, @JsonProperty("memoryReservation") DataSize memoryReservation, @@ -138,6 +140,8 @@ public PipelineStats( this.runningDrivers = runningDrivers; checkArgument(runningPartitionedDrivers >= 0, "runningPartitionedDrivers is negative"); this.runningPartitionedDrivers = runningPartitionedDrivers; + checkArgument(blockedDrivers >= 0, "blockedDrivers is negative"); + this.blockedDrivers = blockedDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; @@ -239,6 +243,12 @@ public int getRunningPartitionedDrivers() return runningPartitionedDrivers; } + @JsonProperty + public int getBlockedDrivers() + { + return blockedDrivers; + } + @JsonProperty public int getCompletedDrivers() { @@ -367,6 +377,7 @@ public PipelineStats summarize() queuedPartitionedDrivers, runningDrivers, runningPartitionedDrivers, + blockedDrivers, completedDrivers, memoryReservation, systemMemoryReservation, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java index 3baba275ab907..961ae7164e1d0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.Page; import java.util.Optional; -import java.util.function.Function; /** * This class is responsible for iterating over build rows, which have @@ -30,7 +29,7 @@ public interface PositionLinks /** * Initialize iteration over position links. Returns first potentially eligible * join position starting from (and including) position argument. - * + *

* When there are no more position -1 is returned */ int start(int position, int probePosition, Page allProbeChannelsPage); @@ -40,7 +39,7 @@ public interface PositionLinks */ int next(int position, int probePosition, Page allProbeChannelsPage); - interface Builder + interface FactoryBuilder { /** * @return value that should be used in future references to created position links @@ -51,6 +50,25 @@ interface Builder * JoinFilterFunction has to be created and supplied for each thread using PositionLinks * since JoinFilterFunction is not thread safe... */ - Function, PositionLinks> build(); + Factory build(); + + /** + * @return number of linked elements + */ + int size(); + + default boolean isEmpty() + { + return size() == 0; + } + } + + interface Factory + { + /** + * JoinFilterFunction has to be created and supplied for each thread using PositionLinks + * since JoinFilterFunction is not thread safe... + */ + PositionLinks create(Optional joinFilterFunction); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java index 3fb4a621a9d54..847d2292c5004 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java @@ -25,6 +25,8 @@ import com.facebook.presto.spi.RecordPageSource; import com.facebook.presto.spi.UpdatablePageSource; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.split.EmptySplit; +import com.facebook.presto.split.EmptySplitPageSource; import com.facebook.presto.split.PageSourceProvider; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.base.Throwables; @@ -124,6 +126,10 @@ public Supplier> addSplit(Split split) } blocked.set(null); + if (split.getConnectorSplit() instanceof EmptySplit) { + pageSource = new EmptySplitPageSource(); + } + return () -> { if (pageSource instanceof UpdatablePageSource) { return Optional.of((UpdatablePageSource) pageSource); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java index 6f315aa839830..ae53063acf539 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableList; +import org.openjdk.jol.info.ClassLayout; import java.util.List; import java.util.Optional; @@ -31,6 +32,7 @@ public class SimplePagesHashStrategy implements PagesHashStrategy { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SimplePagesHashStrategy.class).instanceSize(); private final List types; private final List outputChannels; private final List> channels; @@ -70,7 +72,7 @@ public int getChannelCount() @Override public long getSizeInBytes() { - return channels.stream() + return INSTANCE_SIZE + channels.stream() .flatMap(List::stream) .mapToLong(Block::getRetainedSizeInBytes) .sum(); @@ -210,9 +212,7 @@ public boolean positionEqualsPositionIgnoreNulls(int leftBlockIndex, int leftPos public boolean isPositionNull(int blockIndex, int blockPosition) { for (int hashChannel : hashChannels) { - List channel = channels.get(hashChannel); - Block block = channel.get(blockIndex); - if (block.isNull(blockPosition)) { + if (isChannelPositionNull(hashChannel, blockIndex, blockPosition)) { return true; } } @@ -220,16 +220,34 @@ public boolean isPositionNull(int blockIndex, int blockPosition) } @Override - public int compare(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition) + public int compareSortChannelPositions(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition) { - if (!sortChannel.isPresent()) { - throw new UnsupportedOperationException(); - } - int channel = sortChannel.get().getChannel(); + int channel = getSortChannel(); Block leftBlock = channels.get(channel).get(leftBlockIndex); Block rightBlock = channels.get(channel).get(rightBlockIndex); return types.get(channel).compareTo(leftBlock, leftBlockPosition, rightBlock, rightBlockPosition); } + + @Override + public boolean isSortChannelPositionNull(int blockIndex, int blockPosition) + { + return isChannelPositionNull(getSortChannel(), blockIndex, blockPosition); + } + + private boolean isChannelPositionNull(int channelIndex, int blockIndex, int blockPosition) + { + List channel = channels.get(channelIndex); + Block block = channel.get(blockIndex); + return block.isNull(blockPosition); + } + + private int getSortChannel() + { + if (!sortChannel.isPresent()) { + throw new UnsupportedOperationException(); + } + return sortChannel.get().getChannel(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java index ca4447bd03c2a..51cdf7077bc5d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java @@ -18,11 +18,14 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntComparator; +import it.unimi.dsi.fastutil.longs.LongArrayList; +import org.openjdk.jol.info.ClassLayout; import java.util.List; import java.util.Optional; -import java.util.function.Function; +import static com.facebook.presto.operator.SyntheticAddress.decodePosition; +import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static com.google.common.base.Preconditions.checkState; import static io.airlift.slice.SizeOf.sizeOf; import static java.util.Objects.requireNonNull; @@ -52,22 +55,42 @@ public final class SortedPositionLinks implements PositionLinks { - public static class Builder implements PositionLinks.Builder + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SortedPositionLinks.class).instanceSize(); + + public static class FactoryBuilder + implements PositionLinks.FactoryBuilder { private final Int2ObjectMap positionLinks; private final int size; private final IntComparator comparator; + private final PagesHashStrategy pagesHashStrategy; + private final LongArrayList addresses; - public Builder(int size, IntComparator comparator) + public FactoryBuilder(int size, PagesHashStrategy pagesHashStrategy, LongArrayList addresses) { - this.comparator = comparator; this.size = size; + this.comparator = new PositionComparator(pagesHashStrategy, addresses); + this.pagesHashStrategy = pagesHashStrategy; + this.addresses = addresses; positionLinks = new Int2ObjectOpenHashMap<>(); } @Override public int link(int from, int to) { + // don't add _from_ row to chain if its sort channel value is null + if (isNull(from)) { + // _to_ row sort channel value might be null. However, in such + // case it will be the only element in the chain, so sorted position + // links enumeration will produce correct results. + return to; + } + + // don't add _to_ row to chain if its sort channel value is null + if (isNull(to)) { + return from; + } + // make sure that from value is the smaller one if (comparator.compare(from, to) > 0) { // _from_ is larger so, just add to current chain _to_ @@ -87,10 +110,18 @@ public int link(int from, int to) } } + private boolean isNull(int position) + { + long pageAddress = addresses.getLong(position); + int blockIndex = decodeSliceIndex(pageAddress); + int blockPosition = decodePosition(pageAddress); + return pagesHashStrategy.isSortChannelPositionNull(blockIndex, blockPosition); + } + @Override - public Function, PositionLinks> build() + public Factory build() { - ArrayPositionLinks.Builder builder = ArrayPositionLinks.builder(size); + ArrayPositionLinks.FactoryBuilder arrayPositionLinksFactoryBuilder = ArrayPositionLinks.builder(size); int[][] sortedPositionLinks = new int[size][]; for (Int2ObjectMap.Entry entry : positionLinks.int2ObjectEntrySet()) { @@ -107,23 +138,29 @@ public Function, PositionLinks> build() // tail to head, so we must add them in descending order to have // smallest element as a head for (int i = positions.size() - 2; i >= 0; i--) { - builder.link(positions.get(i), positions.get(i + 1)); + arrayPositionLinksFactoryBuilder.link(positions.get(i), positions.get(i + 1)); } // add link from starting position to position links chain if (!positions.isEmpty()) { - builder.link(key, positions.get(0)); + arrayPositionLinksFactoryBuilder.link(key, positions.get(0)); } } return lessThanFunction -> { checkState(lessThanFunction.isPresent(), "Using SortedPositionLinks without lessThanFunction"); return new SortedPositionLinks( - builder.build().apply(lessThanFunction), + arrayPositionLinksFactoryBuilder.build().create(Optional.empty()), sortedPositionLinks, lessThanFunction.get()); }; } + + @Override + public int size() + { + return positionLinks.size(); + } } private final PositionLinks positionLinks; @@ -136,12 +173,21 @@ private SortedPositionLinks(PositionLinks positionLinks, int[][] sortedPositionL this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); this.sortedPositionLinks = requireNonNull(sortedPositionLinks, "sortedPositionLinks is null"); this.lessThanFunction = requireNonNull(lessThanFunction, "lessThanFunction is null"); - this.sizeInBytes = positionLinks.getSizeInBytes() + sizeOf(sortedPositionLinks); + this.sizeInBytes = INSTANCE_SIZE + positionLinks.getSizeInBytes() + sizeOfPositionLinks(sortedPositionLinks); + } + + private long sizeOfPositionLinks(int[][] sortedPositionLinks) + { + long retainedSize = sizeOf(sortedPositionLinks); + for (int[] element : sortedPositionLinks) { + retainedSize += sizeOf(element); + } + return retainedSize; } - public static Builder builder(int size, IntComparator comparator) + public static FactoryBuilder builder(int size, PagesHashStrategy pagesHashStrategy, LongArrayList addresses) { - return new Builder(size, comparator); + return new FactoryBuilder(size, pagesHashStrategy, addresses); } @Override @@ -221,4 +267,37 @@ private boolean applyLessThanFunction(long leftPosition, int rightPosition, Page { return lessThanFunction.filter((int) leftPosition, rightPosition, rightPage); } + + private static class PositionComparator + implements IntComparator + { + private final PagesHashStrategy pagesHashStrategy; + private final LongArrayList addresses; + + PositionComparator(PagesHashStrategy pagesHashStrategy, LongArrayList addresses) + { + this.pagesHashStrategy = pagesHashStrategy; + this.addresses = addresses; + } + + @Override + public int compare(int leftPosition, int rightPosition) + { + long leftPageAddress = addresses.getLong(leftPosition); + int leftBlockIndex = decodeSliceIndex(leftPageAddress); + int leftBlockPosition = decodePosition(leftPageAddress); + + long rightPageAddress = addresses.getLong(rightPosition); + int rightBlockIndex = decodeSliceIndex(rightPageAddress); + int rightBlockPosition = decodePosition(rightPageAddress); + + return pagesHashStrategy.compareSortChannelPositions(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition); + } + + @Override + public int compare(Integer leftPosition, Integer rightPosition) + { + return compare(leftPosition.intValue(), rightPosition.intValue()); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TableScanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TableScanOperator.java index 72ec3ec443e42..1732f6b036e01 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TableScanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TableScanOperator.java @@ -20,6 +20,8 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.UpdatablePageSource; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.split.EmptySplit; +import com.facebook.presto.split.EmptySplitPageSource; import com.facebook.presto.split.PageSourceProvider; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.base.Throwables; @@ -159,6 +161,10 @@ public Supplier> addSplit(Split split) blocked.set(null); + if (split.getConnectorSplit() instanceof EmptySplit) { + source = new EmptySplitPageSource(); + } + return () -> { if (source instanceof UpdatablePageSource) { return Optional.of((UpdatablePageSource) source); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TaskContext.java b/presto-main/src/main/java/com/facebook/presto/operator/TaskContext.java index aabfdbe639183..00a1a9d16e443 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TaskContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TaskContext.java @@ -31,6 +31,7 @@ import javax.annotation.concurrent.ThreadSafe; import java.util.List; +import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicLong; @@ -282,6 +283,7 @@ public TaskStats getTaskStats() int queuedPartitionedDrivers = 0; int runningDrivers = 0; int runningPartitionedDrivers = 0; + int blockedDrivers = 0; int completedDrivers = 0; long totalScheduledTime = 0; @@ -308,6 +310,7 @@ public TaskStats getTaskStats() queuedPartitionedDrivers += pipeline.getQueuedPartitionedDrivers(); runningDrivers += pipeline.getRunningDrivers(); runningPartitionedDrivers += pipeline.getRunningPartitionedDrivers(); + blockedDrivers += pipeline.getBlockedDrivers(); completedDrivers += pipeline.getCompletedDrivers(); totalScheduledTime += pipeline.getTotalScheduledTime().roundTo(NANOSECONDS); @@ -354,13 +357,15 @@ public TaskStats getTaskStats() lastMemoryReservation = currentMemory; } - boolean fullyBlocked = pipelineStats.stream() - .filter(pipeline -> pipeline.getRunningDrivers() > 0 || pipeline.getRunningPartitionedDrivers() > 0) - .allMatch(PipelineStats::isFullyBlocked); - ImmutableSet blockedReasons = pipelineStats.stream() - .filter(pipeline -> pipeline.getRunningDrivers() > 0 || pipeline.getRunningPartitionedDrivers() > 0) + Set runningPipelineStats = pipelineStats.stream() + .filter(pipeline -> pipeline.getRunningDrivers() > 0 || pipeline.getRunningPartitionedDrivers() > 0 || pipeline.getBlockedDrivers() > 0) + .collect(toImmutableSet()); + ImmutableSet blockedReasons = runningPipelineStats.stream() .flatMap(pipeline -> pipeline.getBlockedReasons().stream()) .collect(toImmutableSet()); + + boolean fullyBlocked = !runningPipelineStats.isEmpty() && runningPipelineStats.stream().allMatch(PipelineStats::isFullyBlocked); + return new TaskStats( taskStateMachine.getCreatedTime(), executionStartTime.get(), @@ -374,6 +379,7 @@ public TaskStats getTaskStats() queuedPartitionedDrivers, runningDrivers, runningPartitionedDrivers, + blockedDrivers, completedDrivers, cumulativeMemory.get(), succinctBytes(memoryReservation.get()), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TaskStats.java b/presto-main/src/main/java/com/facebook/presto/operator/TaskStats.java index c3177dc23f7c4..2c0245071f5b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TaskStats.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TaskStats.java @@ -48,6 +48,7 @@ public class TaskStats private final int queuedPartitionedDrivers; private final int runningDrivers; private final int runningPartitionedDrivers; + private final int blockedDrivers; private final int completedDrivers; private final double cumulativeMemory; @@ -87,6 +88,7 @@ public TaskStats(DateTime createTime, DateTime endTime) 0, 0, 0, + 0, 0.0, new DataSize(0, BYTE), new DataSize(0, BYTE), @@ -120,6 +122,7 @@ public TaskStats( @JsonProperty("queuedPartitionedDrivers") int queuedPartitionedDrivers, @JsonProperty("runningDrivers") int runningDrivers, @JsonProperty("runningPartitionedDrivers") int runningPartitionedDrivers, + @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, @JsonProperty("cumulativeMemory") double cumulativeMemory, @@ -164,6 +167,9 @@ public TaskStats( checkArgument(runningPartitionedDrivers >= 0, "runningPartitionedDrivers is negative"); this.runningPartitionedDrivers = runningPartitionedDrivers; + checkArgument(blockedDrivers >= 0, "blockedDrivers is negative"); + this.blockedDrivers = blockedDrivers; + checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; @@ -257,6 +263,12 @@ public int getRunningDrivers() return runningDrivers; } + @JsonProperty + public int getBlockedDrivers() + { + return blockedDrivers; + } + @JsonProperty public int getCompletedDrivers() { @@ -386,6 +398,7 @@ public TaskStats summarize() queuedPartitionedDrivers, runningDrivers, runningPartitionedDrivers, + blockedDrivers, completedDrivers, cumulativeMemory, memoryReservation, @@ -420,6 +433,7 @@ public TaskStats summarizeFinal() queuedPartitionedDrivers, runningDrivers, runningPartitionedDrivers, + blockedDrivers, completedDrivers, cumulativeMemory, memoryReservation, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java index cda3a083f1952..fd7eb17072b32 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/UnnestOperator.java @@ -16,10 +16,10 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import java.util.ArrayList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java index be36b03be0bc6..b426ffa61d46e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/WindowOperator.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -28,6 +29,7 @@ import java.util.List; import java.util.Optional; +import java.util.function.BiPredicate; import java.util.stream.Stream; import static com.facebook.presto.spi.block.SortOrder.ASC_NULLS_LAST; @@ -447,18 +449,7 @@ private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, checkArgument(page.getPositionCount() > 0, "Must have at least one position"); checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds"); - // Short circuit if the whole page has the same value - if (pagesHashStrategy.rowEqualsRow(startPosition, page, page.getPositionCount() - 1, page)) { - return page.getPositionCount(); - } - - // TODO: do position binary search - int endPosition = startPosition + 1; - while (endPosition < page.getPositionCount() && - pagesHashStrategy.rowEqualsRow(endPosition - 1, page, endPosition, page)) { - endPosition++; - } - return endPosition; + return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page)); } // Assumes input grouped on relevant pagesHashStrategy columns @@ -467,17 +458,63 @@ private static int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHa checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position"); checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds"); - // Short circuit if the whole page has the same value - if (pagesIndex.positionEqualsPosition(pagesHashStrategy, startPosition, pagesIndex.getPositionCount() - 1)) { - return pagesIndex.getPositionCount(); + return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition)); + } + + /** + * @param startPosition - inclusive + * @param endPosition - exclusive + * @param comparator - returns true if positions given as parameters are equal + * @return the end of the group position exclusive (the position the very next group starts) + */ + @VisibleForTesting + static int findEndPosition(int startPosition, int endPosition, BiPredicate comparator) + { + checkArgument(startPosition >= 0, "startPosition must be greater or equal than zero: %s", startPosition); + checkArgument(endPosition > 0, "endPosition must be greater than zero: %s", endPosition); + checkArgument(startPosition < endPosition, "startPosition must be less than endPosition: %s < %s", startPosition, endPosition); + + int left = startPosition; + int right = endPosition - 1; + for (int i = 0; i < endPosition - startPosition; i++) { + int distance = right - left; + + if (distance == 0) { + return right + 1; + } + + if (distance == 1) { + if (comparator.test(left, right)) { + return right + 1; + } + return right; + } + + int mid = left + distance / 2; + if (comparator.test(left, mid)) { + // explore to the right + left = mid; + } + else { + // explore to the left + right = mid; + } } - // TODO: do position binary search - int endPosition = startPosition + 1; - while ((endPosition < pagesIndex.getPositionCount()) && - pagesIndex.positionEqualsPosition(pagesHashStrategy, endPosition - 1, endPosition)) { - endPosition++; + // hasn't managed to find a solution after N iteration. Probably the input is not sorted. Lets verify it. + for (int first = startPosition; first < endPosition; first++) { + boolean previousPairsWereEqual = true; + for (int second = first + 1; second < endPosition; second++) { + if (!comparator.test(first, second)) { + previousPairsWereEqual = false; + } + else if (!previousPairsWereEqual) { + throw new IllegalArgumentException("The input is not sorted"); + } + } } - return endPosition; + + // the input is sorted, but the algorithm has still failed + throw new IllegalArgumentException("failed to find a group ending"); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java index 793288962bb2d..23e0f86ccfccc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxByNAggregationFunction.java @@ -23,11 +23,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -124,7 +123,7 @@ public static void output(ArrayType outputType, MinMaxByNState state, BlockBuild Type elementType = outputType.getElementType(); BlockBuilder arrayBlockBuilder = out.beginBlockEntry(); - BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), heap.getCapacity()); + BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(null, heap.getCapacity()); long startSize = heap.getEstimatedSize(); heap.popAll(reversedBlockBuilder); state.addMemoryUsage(heap.getEstimatedSize() - startSize); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java index b5885fe7dd4c4..af718c7de6814 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AbstractMinMaxNAggregationFunction.java @@ -23,11 +23,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -148,7 +147,7 @@ public static void output(ArrayType outputType, MinMaxNState state, BlockBuilder Type elementType = outputType.getElementType(); - BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), heap.getCapacity()); + BlockBuilder reversedBlockBuilder = elementType.createBlockBuilder(null, heap.getCapacity()); long startSize = heap.getEstimatedSize(); heap.popAll(reversedBlockBuilder); state.addMemoryUsage(heap.getEstimatedSize() - startSize); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java index 8d6caf4ed36c4..398e3f4917e51 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/ArrayAggregationFunction.java @@ -23,13 +23,12 @@ import com.facebook.presto.operator.aggregation.state.ArrayAggregationStateSerializer; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -119,7 +118,7 @@ public static void input(Type type, ArrayAggregationState state, Block value, in { BlockBuilder blockBuilder = state.getBlockBuilder(); if (blockBuilder == null) { - blockBuilder = type.createBlockBuilder(new BlockBuilderStatus(), 4); + blockBuilder = type.createBlockBuilder(null, 4); state.setBlockBuilder(blockBuilder); } long startSize = blockBuilder.getRetainedSizeInBytes(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleHistogramAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleHistogramAggregation.java index 2acf52edf2091..e84be451602af 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleHistogramAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/DoubleHistogramAggregation.java @@ -16,7 +16,6 @@ import com.facebook.presto.operator.aggregation.state.DoubleHistogramStateSerializer; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AggregationFunction; @@ -96,7 +95,7 @@ public static void output(@AggregationState State state, BlockBuilder out) } else { Map value = state.get().getBuckets(); - BlockBuilder blockBuilder = DoubleType.DOUBLE.createBlockBuilder(new BlockBuilderStatus(), value.size() * 2); + BlockBuilder blockBuilder = DoubleType.DOUBLE.createBlockBuilder(null, value.size() * 2); for (Map.Entry entry : value.entrySet()) { DoubleType.DOUBLE.writeDouble(blockBuilder, entry.getKey()); DoubleType.DOUBLE.writeDouble(blockBuilder, entry.getValue()); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java index 2b4e12ad6ad4a..edfcf1efa2a1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/Histogram.java @@ -24,10 +24,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -40,6 +40,7 @@ import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -75,16 +76,17 @@ public String getDescription() public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { Type keyType = boundVariables.getTypeVariable("K"); - Type valueType = BigintType.BIGINT; - return generateAggregation(keyType, valueType); + Type outputType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(BIGINT.getTypeSignature()))); + return generateAggregation(keyType, outputType); } - private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType) + private static InternalAggregationFunction generateAggregation(Type keyType, Type outputType) { DynamicClassLoader classLoader = new DynamicClassLoader(Histogram.class.getClassLoader()); List inputTypes = ImmutableList.of(keyType); - Type outputType = new MapType(keyType, valueType); - HistogramStateSerializer stateSerializer = new HistogramStateSerializer(keyType); + HistogramStateSerializer stateSerializer = new HistogramStateSerializer(keyType, outputType); Type intermediateType = stateSerializer.getSerializedType(); MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyType); MethodHandle outputFunction = OUTPUT_FUNCTION.bindTo(outputType); @@ -154,8 +156,7 @@ public static void output(Type type, HistogramState state, BlockBuilder out) out.appendNull(); } else { - Block block = typedHistogram.serialize(); - type.writeObject(out, block); + typedHistogram.serialize(out); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/KeyValuePairs.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/KeyValuePairs.java index 8070674c556ac..7e20733f3cccf 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/KeyValuePairs.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/KeyValuePairs.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; import org.openjdk.jol.info.ClassLayout; @@ -54,8 +53,8 @@ public KeyValuePairs(Type keyType, Type valueType) { this.keyType = requireNonNull(keyType, "keyType is null"); this.valueType = requireNonNull(valueType, "valueType is null"); - keyBlockBuilder = this.keyType.createBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES, expectedValueSize(keyType, EXPECTED_ENTRY_SIZE)); - valueBlockBuilder = this.valueType.createBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); + keyBlockBuilder = this.keyType.createBlockBuilder(null, EXPECTED_ENTRIES, expectedValueSize(keyType, EXPECTED_ENTRY_SIZE)); + valueBlockBuilder = this.valueType.createBlockBuilder(null, EXPECTED_ENTRIES, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); hashCapacity = arraySize(EXPECTED_ENTRIES, FILL_RATIO); this.maxFill = calculateMaxFill(hashCapacity); this.hashMask = hashCapacity - 1; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java index b3c1a6c721a1f..f5bc8a320082f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapAggregationFunction.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; @@ -21,12 +20,13 @@ import com.facebook.presto.operator.aggregation.state.KeyValuePairStateSerializer; import com.facebook.presto.operator.aggregation.state.KeyValuePairsState; import com.facebook.presto.operator.aggregation.state.KeyValuePairsStateFactory; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -40,11 +40,9 @@ import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.lang.String.format; public class MapAggregationFunction extends SqlAggregationFunction @@ -75,15 +73,17 @@ public InternalAggregationFunction specialize(BoundVariables boundVariables, int { Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - return generateAggregation(keyType, valueType); + MapType outputType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); + return generateAggregation(keyType, valueType, outputType); } - private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType) + private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType, MapType outputType) { DynamicClassLoader classLoader = new DynamicClassLoader(MapAggregationFunction.class.getClassLoader()); List inputTypes = ImmutableList.of(keyType, valueType); - Type outputType = new MapType(keyType, valueType); - KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(keyType, valueType); + KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(outputType); Type intermediateType = stateSerializer.getSerializedType(); AggregationMetadata metadata = new AggregationMetadata( @@ -118,12 +118,7 @@ public static void input(Type keyType, Type valueType, KeyValuePairsState state, } long startSize = pairs.estimatedInMemorySize(); - try { - pairs.add(key, value, position, position); - } - catch (ExceededMemoryLimitException e) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("The result of map_agg may not exceed %s", e.getMaxMemory())); - } + pairs.add(key, value, position, position); state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } @@ -135,12 +130,7 @@ public static void combine(KeyValuePairsState state, KeyValuePairsState otherSta KeyValuePairs pairs = state.get(); long startSize = pairs.estimatedInMemorySize(); for (int i = 0; i < keys.getPositionCount(); i++) { - try { - pairs.add(keys, values, i, i); - } - catch (ExceededMemoryLimitException e) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("The result of map_agg may not exceed %s", e.getMaxMemory())); - } + pairs.add(keys, values, i, i); } state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java index 21c1560143715..d644858cb159a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MapUnionAggregation.java @@ -22,9 +22,11 @@ import com.facebook.presto.operator.aggregation.state.KeyValuePairsStateFactory; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -65,15 +67,17 @@ public InternalAggregationFunction specialize(BoundVariables boundVariables, int { Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - return generateAggregation(keyType, valueType); + MapType outputType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); + return generateAggregation(keyType, valueType, outputType); } - private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType) + private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType, MapType outputType) { DynamicClassLoader classLoader = new DynamicClassLoader(MapUnionAggregation.class.getClassLoader()); - Type outputType = new MapType(keyType, valueType); List inputTypes = ImmutableList.of(outputType); - KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(keyType, valueType); + KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(outputType); Type intermediateType = stateSerializer.getSerializedType(); AggregationMetadata metadata = new AggregationMetadata( diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java index aaf20ee460b94..d992282d6a0a7 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultiKeyValuePairs.java @@ -16,16 +16,10 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; -import com.google.common.collect.ImmutableList; import org.openjdk.jol.info.ClassLayout; -import java.util.Optional; - import static com.facebook.presto.type.TypeUtils.expectedValueSize; import static java.util.Objects.requireNonNull; @@ -41,15 +35,12 @@ public class MultiKeyValuePairs private final BlockBuilder valueBlockBuilder; private final Type valueType; - private final RowType serializedRowType; - public MultiKeyValuePairs(Type keyType, Type valueType) { this.keyType = requireNonNull(keyType, "keyType is null"); this.valueType = requireNonNull(valueType, "valueType is null"); - keyBlockBuilder = this.keyType.createBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES, expectedValueSize(keyType, EXPECTED_ENTRY_SIZE)); - valueBlockBuilder = this.valueType.createBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); - serializedRowType = new RowType(ImmutableList.of(keyType, valueType), Optional.empty()); + keyBlockBuilder = this.keyType.createBlockBuilder(null, EXPECTED_ENTRIES, expectedValueSize(keyType, EXPECTED_ENTRY_SIZE)); + valueBlockBuilder = this.valueType.createBlockBuilder(null, EXPECTED_ENTRIES, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); } public MultiKeyValuePairs(Block serialized, Type keyType, Type valueType) @@ -91,13 +82,13 @@ public void serialize(BlockBuilder out) /** * Serialize as a multimap: map(key, array(value)), each key can be associated with multiple values */ - public Block toMultimapNativeEncoding() + public void toMultimapNativeEncoding(BlockBuilder blockBuilder) { Block keys = keyBlockBuilder.build(); Block values = valueBlockBuilder.build(); // Merge values of the same key into an array - BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(new BlockBuilderStatus(), keys.getPositionCount(), expectedValueSize(keyType, EXPECTED_ENTRY_SIZE)); + BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, keys.getPositionCount(), expectedValueSize(keyType, EXPECTED_ENTRY_SIZE)); ObjectBigArray valueArrayBlockBuilders = new ObjectBigArray<>(); valueArrayBlockBuilders.ensureCapacity(keys.getPositionCount()); TypedSet keySet = new TypedSet(keyType, keys.getPositionCount()); @@ -105,7 +96,7 @@ public Block toMultimapNativeEncoding() if (!keySet.contains(keys, keyValueIndex)) { keySet.add(keys, keyValueIndex); keyType.appendTo(keys, keyValueIndex, distinctKeyBlockBuilder); - BlockBuilder valueArrayBuilder = valueType.createBlockBuilder(new BlockBuilderStatus(), 10, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); + BlockBuilder valueArrayBuilder = valueType.createBlockBuilder(null, 10, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); valueArrayBlockBuilders.set(keySet.positionOf(keys, keyValueIndex), valueArrayBuilder); } valueType.appendTo(values, keyValueIndex, valueArrayBlockBuilders.get(keySet.positionOf(keys, keyValueIndex))); @@ -114,13 +105,12 @@ public Block toMultimapNativeEncoding() // Write keys and value arrays into one Block Block distinctKeys = distinctKeyBlockBuilder.build(); Type valueArrayType = new ArrayType(valueType); - BlockBuilder multimapBlockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueArrayType), new BlockBuilderStatus(), distinctKeyBlockBuilder.getPositionCount()); + BlockBuilder multimapBlockBuilder = blockBuilder.beginBlockEntry(); for (int i = 0; i < distinctKeys.getPositionCount(); i++) { keyType.appendTo(distinctKeys, i, multimapBlockBuilder); valueArrayType.writeObject(multimapBlockBuilder, valueArrayBlockBuilders.get(i).build()); } - - return multimapBlockBuilder.build(); + blockBuilder.closeEntry(); } public long estimatedInMemorySize() diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java index 130c5e4376db3..3f9efbf7c3a84 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/MultimapAggregationFunction.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; @@ -21,13 +20,13 @@ import com.facebook.presto.operator.aggregation.state.MultiKeyValuePairStateSerializer; import com.facebook.presto.operator.aggregation.state.MultiKeyValuePairsState; import com.facebook.presto.operator.aggregation.state.MultiKeyValuePairsStateFactory; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; @@ -41,11 +40,9 @@ import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL; import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.lang.String.format; public class MultimapAggregationFunction extends SqlAggregationFunction @@ -76,14 +73,16 @@ public InternalAggregationFunction specialize(BoundVariables boundVariables, int { Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - return generateAggregation(keyType, valueType); + Type outputType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(new ArrayType(valueType).getTypeSignature()))); + return generateAggregation(keyType, valueType, outputType); } - private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType) + private static InternalAggregationFunction generateAggregation(Type keyType, Type valueType, Type outputType) { DynamicClassLoader classLoader = new DynamicClassLoader(MultimapAggregationFunction.class.getClassLoader()); List inputTypes = ImmutableList.of(keyType, valueType); - Type outputType = new MapType(keyType, new ArrayType(valueType)); MultiKeyValuePairStateSerializer stateSerializer = new MultiKeyValuePairStateSerializer(keyType, valueType); Type intermediateType = stateSerializer.getSerializedType(); @@ -119,12 +118,7 @@ public static void input(MultiKeyValuePairsState state, Block key, Block value, } long startSize = pairs.estimatedInMemorySize(); - try { - pairs.add(key, value, position, position); - } - catch (ExceededMemoryLimitException e) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("The result of map_agg may not exceed %s", e.getMaxMemory())); - } + pairs.add(key, value, position, position); state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } @@ -136,12 +130,7 @@ public static void combine(MultiKeyValuePairsState state, MultiKeyValuePairsStat MultiKeyValuePairs pairs = state.get(); long startSize = pairs.estimatedInMemorySize(); for (int i = 0; i < keys.getPositionCount(); i++) { - try { - pairs.add(keys, values, i, i); - } - catch (ExceededMemoryLimitException e) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("The result of map_agg may not exceed %s", e.getMaxMemory())); - } + pairs.add(keys, values, i, i); } state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize); } @@ -157,9 +146,7 @@ public static void output(MultiKeyValuePairsState state, BlockBuilder out) out.appendNull(); } else { - Block block = pairs.toMultimapNativeEncoding(); - out.writeObject(block); - out.closeEntry(); + pairs.toMultimapNativeEncoding(out); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealHistogramAggregation.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealHistogramAggregation.java index c98d94a1ef847..70686368f37f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealHistogramAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/RealHistogramAggregation.java @@ -15,7 +15,6 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.AggregationFunction; import com.facebook.presto.spi.function.AggregationState; import com.facebook.presto.spi.function.CombineFunction; @@ -63,7 +62,7 @@ public static void output(@AggregationState DoubleHistogramAggregation.State sta } else { Map value = state.get().getBuckets(); - BlockBuilder blockBuilder = REAL.createBlockBuilder(new BlockBuilderStatus(), value.size() * 2); + BlockBuilder blockBuilder = REAL.createBlockBuilder(null, value.size() * 2); for (Map.Entry entry : value.entrySet()) { REAL.writeLong(blockBuilder, floatToRawIntBits(entry.getKey().floatValue())); REAL.writeLong(blockBuilder, floatToRawIntBits(entry.getValue().floatValue())); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHeap.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHeap.java index 7fbe9ce08e0e4..945fd0f70f6eb 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHeap.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHeap.java @@ -15,7 +15,6 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; import org.openjdk.jol.info.ClassLayout; @@ -43,7 +42,7 @@ public TypedHeap(BlockComparator comparator, Type type, int capacity) this.type = type; this.capacity = capacity; this.heapIndex = new int[capacity]; - this.heapBlockBuilder = type.createBlockBuilder(new BlockBuilderStatus(), capacity); + this.heapBlockBuilder = type.createBlockBuilder(null, capacity); } public int getCapacity() @@ -171,7 +170,7 @@ private void compactIfNecessary() if (heapBlockBuilder.getSizeInBytes() < COMPACT_THRESHOLD_BYTES || heapBlockBuilder.getPositionCount() / positionCount < COMPACT_THRESHOLD_RATIO) { return; } - BlockBuilder newHeapBlockBuilder = type.createBlockBuilder(new BlockBuilderStatus(), heapBlockBuilder.getPositionCount()); + BlockBuilder newHeapBlockBuilder = type.createBlockBuilder(null, heapBlockBuilder.getPositionCount()); for (int i = 0; i < positionCount; i++) { type.appendTo(heapBlockBuilder, heapIndex[i], newHeapBlockBuilder); heapIndex[i] = i; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java index 64f8744be0ed5..215f0ef458349 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedHistogram.java @@ -18,11 +18,8 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.type.TypeUtils; -import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import org.openjdk.jol.info.ClassLayout; @@ -62,7 +59,7 @@ public TypedHistogram(Type type, int expectedSize) maxFill = calculateMaxFill(hashCapacity); mask = hashCapacity - 1; - values = this.type.createBlockBuilder(new BlockBuilderStatus(), hashCapacity); + values = this.type.createBlockBuilder(null, hashCapacity); hashPositions = new IntBigArray(-1); hashPositions.ensureCapacity(hashCapacity); counts = new LongBigArray(); @@ -93,15 +90,15 @@ private LongBigArray getCounts() return counts; } - public Block serialize() + public void serialize(BlockBuilder out) { Block valuesBlock = values.build(); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(type, BIGINT), new BlockBuilderStatus(), valuesBlock.getPositionCount() * 2); + BlockBuilder blockBuilder = out.beginBlockEntry(); for (int i = 0; i < valuesBlock.getPositionCount(); i++) { type.appendTo(valuesBlock, i, blockBuilder); BIGINT.writeLong(blockBuilder, counts.get(i)); } - return blockBuilder.build(); + out.closeEntry(); } public void addAll(TypedHistogram other) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java index c0bbb54ec9b39..1a52fcd6c8e55 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedKeyValueHeap.java @@ -15,10 +15,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import org.openjdk.jol.info.ClassLayout; @@ -53,8 +52,8 @@ public TypedKeyValueHeap(BlockComparator keyComparator, Type keyType, Type value this.valueType = valueType; this.capacity = capacity; this.heapIndex = new int[capacity]; - this.keyBlockBuilder = keyType.createBlockBuilder(new BlockBuilderStatus(), capacity); - this.valueBlockBuilder = valueType.createBlockBuilder(new BlockBuilderStatus(), capacity); + this.keyBlockBuilder = keyType.createBlockBuilder(null, capacity); + this.valueBlockBuilder = valueType.createBlockBuilder(null, capacity); } public static Type getSerializedType(Type keyType, Type valueType) @@ -210,8 +209,8 @@ private void compactIfNecessary() if (keyBlockBuilder.getSizeInBytes() < COMPACT_THRESHOLD_BYTES || keyBlockBuilder.getPositionCount() / positionCount < COMPACT_THRESHOLD_RATIO) { return; } - BlockBuilder newHeapKeyBlockBuilder = keyType.createBlockBuilder(new BlockBuilderStatus(), keyBlockBuilder.getPositionCount()); - BlockBuilder newHeapValueBlockBuilder = valueType.createBlockBuilder(new BlockBuilderStatus(), valueBlockBuilder.getPositionCount()); + BlockBuilder newHeapKeyBlockBuilder = keyType.createBlockBuilder(null, keyBlockBuilder.getPositionCount()); + BlockBuilder newHeapValueBlockBuilder = valueType.createBlockBuilder(null, valueBlockBuilder.getPositionCount()); for (int i = 0; i < positionCount; i++) { keyType.appendTo(keyBlockBuilder, heapIndex[i], newHeapKeyBlockBuilder); valueType.appendTo(valueBlockBuilder, heapIndex[i], newHeapValueBlockBuilder); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java index 1dcb4b71f966e..3ab440d119c31 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -53,7 +52,7 @@ public TypedSet(Type elementType, int expectedSize) { checkArgument(expectedSize >= 0, "expectedSize must not be negative"); this.elementType = requireNonNull(elementType, "elementType must not be null"); - this.elementBlock = elementType.createBlockBuilder(new BlockBuilderStatus(), expectedSize); + this.elementBlock = elementType.createBlockBuilder(null, expectedSize); hashCapacity = arraySize(expectedSize, FILL_RATIO); this.maxFill = calculateMaxFill(hashCapacity); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java index 0dee930316709..0f2e25b332946 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -25,15 +25,15 @@ import com.facebook.presto.spiller.SpillerFactory; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import com.google.common.io.Closer; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; +import java.io.IOException; import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.concurrent.ExecutionException; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; @@ -179,32 +179,25 @@ public Iterator buildResult() return hashAggregationBuilder.buildResult(); } - try { - if (shouldMergeWithMemory(getSizeInMemory())) { - return mergeFromDiskAndMemory(); - } - else { - spillToDisk().get(); - return mergeFromDisk(); - } + if (shouldMergeWithMemory(getSizeInMemory())) { + return mergeFromDiskAndMemory(); } - catch (InterruptedException | ExecutionException e) { - Thread.currentThread().interrupt(); - throw Throwables.propagate(e); + else { + getFutureValue(spillToDisk()); + return mergeFromDisk(); } } @Override public void close() { - if (merger.isPresent()) { - merger.get().close(); - } - if (spiller.isPresent()) { - spiller.get().close(); + try (Closer closer = Closer.create()) { + merger.ifPresent(closer::register); + spiller.ifPresent(closer::register); + mergeHashSort.ifPresent(closer::register); } - if (mergeHashSort.isPresent()) { - mergeHashSort.get().close(); + catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateFactory.java index 887b25ac73686..9eb6c4b3eeb0b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import org.openjdk.jol.info.ClassLayout; import static java.util.Objects.requireNonNull; @@ -50,6 +51,7 @@ public static class GroupedArrayAggregationState extends AbstractGroupedAccumulatorState implements ArrayAggregationState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedArrayAggregationState.class).instanceSize(); private final ObjectBigArray blockBuilders = new ObjectBigArray(); private long size; @@ -62,7 +64,7 @@ public void ensureCapacity(long size) @Override public long getEstimatedSize() { - return size + blockBuilders.sizeOf(); + return INSTANCE_SIZE + size + blockBuilders.sizeOf(); } @Override @@ -94,17 +96,17 @@ public void setBlockBuilder(BlockBuilder value) public static class SingleArrayAggregationState implements ArrayAggregationState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleArrayAggregationState.class).instanceSize(); private BlockBuilder blockBuilder; @Override public long getEstimatedSize() { - if (blockBuilder == null) { - return 0L; - } - else { - return blockBuilder.getRetainedSizeInBytes(); + long estimatedSize = INSTANCE_SIZE; + if (blockBuilder != null) { + estimatedSize += blockBuilder.getRetainedSizeInBytes(); } + return estimatedSize; } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java index d475a369337fd..fe25ee8a46863 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/ArrayAggregationStateSerializer.java @@ -15,10 +15,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; public class ArrayAggregationStateSerializer implements AccumulatorStateSerializer @@ -55,7 +54,7 @@ public void deserialize(Block block, int index, ArrayAggregationState state) { Block stateBlock = (Block) arrayType.getObject(block, index); int positionCount = stateBlock.getPositionCount(); - BlockBuilder blockBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); + BlockBuilder blockBuilder = elementType.createBlockBuilder(null, positionCount); for (int i = 0; i < positionCount; i++) { elementType.appendTo(stateBlock, i, blockBuilder); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileArrayStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileArrayStateFactory.java index 6d1882e0553a8..0e6b14ad98195 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileArrayStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileArrayStateFactory.java @@ -15,11 +15,12 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import io.airlift.slice.SizeOf; import io.airlift.stats.QuantileDigest; +import org.openjdk.jol.info.ClassLayout; import java.util.List; -import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; import static java.util.Objects.requireNonNull; public class DigestAndPercentileArrayStateFactory @@ -53,6 +54,7 @@ public static class GroupedDigestAndPercentileArrayState extends AbstractGroupedAccumulatorState implements DigestAndPercentileArrayState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedDigestAndPercentileArrayState.class).instanceSize(); private final ObjectBigArray digests = new ObjectBigArray<>(); private final ObjectBigArray> percentilesArray = new ObjectBigArray<>(); private long size; @@ -97,13 +99,14 @@ public void addMemoryUsage(int value) @Override public long getEstimatedSize() { - return size + digests.sizeOf() + percentilesArray.sizeOf(); + return INSTANCE_SIZE + size + digests.sizeOf() + percentilesArray.sizeOf(); } } public static class SingleDigestAndPercentileArrayState implements DigestAndPercentileArrayState { + public static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleDigestAndPercentileArrayState.class).instanceSize(); private QuantileDigest digest; private List percentiles; @@ -140,10 +143,14 @@ public void addMemoryUsage(int value) @Override public long getEstimatedSize() { - if (digest == null) { - return SIZE_OF_DOUBLE; + long estimatedSize = INSTANCE_SIZE; + if (digest != null) { + estimatedSize += digest.estimatedInMemorySizeInBytes(); } - return digest.estimatedInMemorySizeInBytes() + SIZE_OF_DOUBLE; + if (percentiles != null) { + estimatedSize += SizeOf.sizeOfDoubleArray(percentiles.size()); + } + return estimatedSize; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileStateFactory.java index cff30ef16d798..e602d7a9f9dd9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/DigestAndPercentileStateFactory.java @@ -17,8 +17,8 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.spi.function.AccumulatorStateFactory; import io.airlift.stats.QuantileDigest; +import org.openjdk.jol.info.ClassLayout; -import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; import static java.util.Objects.requireNonNull; public class DigestAndPercentileStateFactory @@ -52,6 +52,7 @@ public static class GroupedDigestAndPercentileState extends AbstractGroupedAccumulatorState implements DigestAndPercentileState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedDigestAndPercentileState.class).instanceSize(); private final ObjectBigArray digests = new ObjectBigArray<>(); private final DoubleBigArray percentiles = new DoubleBigArray(); private long size; @@ -97,13 +98,14 @@ public void addMemoryUsage(int value) @Override public long getEstimatedSize() { - return size + digests.sizeOf() + percentiles.sizeOf(); + return INSTANCE_SIZE + size + digests.sizeOf() + percentiles.sizeOf(); } } public static class SingleDigestAndPercentileState implements DigestAndPercentileState { + public static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleDigestAndPercentileState.class).instanceSize(); private QuantileDigest digest; private double percentile; @@ -140,10 +142,11 @@ public void addMemoryUsage(int value) @Override public long getEstimatedSize() { - if (digest == null) { - return SIZE_OF_DOUBLE; + long estimatedSize = INSTANCE_SIZE; + if (digest != null) { + estimatedSize += digest.estimatedInMemorySizeInBytes(); } - return digest.estimatedInMemorySizeInBytes() + SIZE_OF_DOUBLE; + return estimatedSize; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateFactory.java index 25e732f3d75fc..5f10f1bc56dd4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.operator.aggregation.TypedHistogram; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import org.openjdk.jol.info.ClassLayout; import static java.util.Objects.requireNonNull; @@ -50,6 +51,7 @@ public static class GroupedState extends AbstractGroupedAccumulatorState implements HistogramState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize(); private final ObjectBigArray typedHistogram = new ObjectBigArray<>(); private long size; @@ -88,13 +90,14 @@ public void addMemoryUsage(long memory) @Override public long getEstimatedSize() { - return size + typedHistogram.sizeOf(); + return INSTANCE_SIZE + size + typedHistogram.sizeOf(); } } public static class SingleState implements HistogramState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize(); private TypedHistogram typedHistogram; @Override @@ -117,10 +120,11 @@ public void addMemoryUsage(long memory) @Override public long getEstimatedSize() { - if (typedHistogram == null) { - return 0; + long estimatedSize = INSTANCE_SIZE; + if (typedHistogram != null) { + estimatedSize += typedHistogram.getEstimatedSize(); } - return typedHistogram.getEstimatedSize(); + return estimatedSize; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java index 4e53f3cdee407..88553458b5c6b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HistogramStateSerializer.java @@ -17,9 +17,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; -import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.MapType; import static com.facebook.presto.operator.aggregation.Histogram.EXPECTED_SIZE_FOR_HASHING; @@ -29,10 +27,10 @@ public class HistogramStateSerializer private final Type type; private final Type serializedType; - public HistogramStateSerializer(Type type) + public HistogramStateSerializer(Type type, Type serializedType) { this.type = type; - this.serializedType = new MapType(type, BigintType.BIGINT); + this.serializedType = serializedType; } @Override @@ -48,7 +46,7 @@ public void serialize(HistogramState state, BlockBuilder out) out.appendNull(); } else { - serializedType.writeObject(out, state.get().serialize()); + state.get().serialize(out); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HyperLogLogStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HyperLogLogStateFactory.java index 75dc0385251fa..ffd79f27c8247 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HyperLogLogStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/HyperLogLogStateFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.spi.function.AccumulatorStateFactory; import io.airlift.stats.cardinality.HyperLogLog; +import org.openjdk.jol.info.ClassLayout; import static java.util.Objects.requireNonNull; @@ -50,6 +51,7 @@ public static class GroupedHyperLogLogState extends AbstractGroupedAccumulatorState implements HyperLogLogState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedHyperLogLogState.class).instanceSize(); private final ObjectBigArray hlls = new ObjectBigArray<>(); private long size; @@ -81,13 +83,14 @@ public void addMemoryUsage(int value) @Override public long getEstimatedSize() { - return size + hlls.sizeOf(); + return INSTANCE_SIZE + size + hlls.sizeOf(); } } public static class SingleHyperLogLogState implements HyperLogLogState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleHyperLogLogState.class).instanceSize(); private HyperLogLog hll; @Override @@ -111,10 +114,11 @@ public void addMemoryUsage(int value) @Override public long getEstimatedSize() { - if (hll == null) { - return 0; + long estimatedSize = INSTANCE_SIZE; + if (hll != null) { + estimatedSize += hll.estimatedInMemorySize(); } - return hll.estimatedInMemorySize(); + return estimatedSize; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java index 1cd0fd0c4fb19..23f7b02daa27a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairStateSerializer.java @@ -17,17 +17,17 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.MapType; public class KeyValuePairStateSerializer implements AccumulatorStateSerializer { private final MapType mapType; - public KeyValuePairStateSerializer(Type keyType, Type valueType) + public KeyValuePairStateSerializer(MapType mapType) { - this.mapType = new MapType(keyType, valueType); + this.mapType = mapType; } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairsStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairsStateFactory.java index a9a2baec96f1c..79670cc5da9da 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairsStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/KeyValuePairsStateFactory.java @@ -17,6 +17,7 @@ import com.facebook.presto.operator.aggregation.KeyValuePairs; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.type.Type; +import org.openjdk.jol.info.ClassLayout; import static java.util.Objects.requireNonNull; @@ -60,6 +61,7 @@ public static class GroupedState extends AbstractGroupedAccumulatorState implements KeyValuePairsState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize(); private final Type keyType; private final Type valueType; private final ObjectBigArray pairs = new ObjectBigArray<>(); @@ -118,13 +120,14 @@ public Type getValueType() @Override public long getEstimatedSize() { - return size + pairs.sizeOf(); + return INSTANCE_SIZE + size + pairs.sizeOf(); } } public static class SingleState implements KeyValuePairsState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize(); private final Type keyType; private final Type valueType; private KeyValuePairs pair; @@ -167,10 +170,11 @@ public Type getValueType() @Override public long getEstimatedSize() { - if (pair == null) { - return 0; + long estimatedSize = INSTANCE_SIZE; + if (pair != null) { + estimatedSize += pair.estimatedInMemorySize(); } - return pair.estimatedInMemorySize(); + return estimatedSize; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java index 6fa2eed83d53c..7add618cf40a9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowAndLongStateFactory.java @@ -52,6 +52,7 @@ public static class GroupedLongDecimalWithOverflowAndLongState extends LongDecimalWithOverflowStateFactory.GroupedLongDecimalWithOverflowState implements LongDecimalWithOverflowAndLongState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedLongDecimalWithOverflowAndLongState.class).instanceSize(); private final LongBigArray longs = new LongBigArray(); @Override @@ -76,7 +77,7 @@ public void setLong(long value) @Override public long getEstimatedSize() { - return unscaledDecimals.sizeOf() + overflows.sizeOf() + numberOfElements * SingleLongDecimalWithOverflowAndLongState.SIZE; + return INSTANCE_SIZE + unscaledDecimals.sizeOf() + overflows.sizeOf() + numberOfElements * SingleLongDecimalWithOverflowAndLongState.SIZE; } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java index e5356866daaf3..29eed9ae5828f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java @@ -20,7 +20,6 @@ import org.openjdk.jol.info.ClassLayout; import static com.facebook.presto.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH; -import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static java.util.Objects.requireNonNull; public class LongDecimalWithOverflowStateFactory @@ -54,6 +53,7 @@ public static class GroupedLongDecimalWithOverflowState extends AbstractGroupedAccumulatorState implements LongDecimalWithOverflowState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedLongDecimalWithOverflowState.class).instanceSize(); protected final ObjectBigArray unscaledDecimals = new ObjectBigArray<>(); protected final LongBigArray overflows = new LongBigArray(); protected long numberOfElements; @@ -96,14 +96,15 @@ public void setOverflow(long overflow) @Override public long getEstimatedSize() { - return unscaledDecimals.sizeOf() + overflows.sizeOf() + numberOfElements * SingleLongDecimalWithOverflowState.SIZE; + return INSTANCE_SIZE + unscaledDecimals.sizeOf() + overflows.sizeOf() + numberOfElements * SingleLongDecimalWithOverflowState.SIZE; } } public static class SingleLongDecimalWithOverflowState implements LongDecimalWithOverflowState { - public static final int SIZE = ClassLayout.parseClass(Slice.class).instanceSize() + UNSCALED_DECIMAL_128_SLICE_LENGTH + SIZE_OF_LONG; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleLongDecimalWithOverflowState.class).instanceSize(); + public static final int SIZE = ClassLayout.parseClass(Slice.class).instanceSize() + UNSCALED_DECIMAL_128_SLICE_LENGTH; protected Slice unscaledDecimal; protected long overflow; @@ -136,9 +137,9 @@ public void setOverflow(long overflow) public long getEstimatedSize() { if (getLongDecimal() == null) { - return SIZE_OF_LONG; + return INSTANCE_SIZE; } - return SIZE; + return INSTANCE_SIZE + SIZE; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxByNStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxByNStateFactory.java index 2d793e48c6d7b..7623f079ec6ca 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxByNStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxByNStateFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.operator.aggregation.TypedKeyValueHeap; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import org.openjdk.jol.info.ClassLayout; public class MinMaxByNStateFactory implements AccumulatorStateFactory @@ -48,6 +49,7 @@ public static class GroupedMinMaxByNState extends AbstractGroupedAccumulatorState implements MinMaxByNState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedMinMaxByNState.class).instanceSize(); private final ObjectBigArray heaps = new ObjectBigArray<>(); private long size; @@ -60,7 +62,7 @@ public void ensureCapacity(long size) @Override public long getEstimatedSize() { - return heaps.sizeOf() + size; + return INSTANCE_SIZE + heaps.sizeOf() + size; } @Override @@ -90,15 +92,17 @@ public void addMemoryUsage(long memory) public static class SingleMinMaxByNState implements MinMaxByNState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMinMaxByNState.class).instanceSize(); private TypedKeyValueHeap typedKeyValueHeap; @Override public long getEstimatedSize() { - if (typedKeyValueHeap == null) { - return 0; + long estimatedSize = INSTANCE_SIZE; + if (typedKeyValueHeap != null) { + estimatedSize += typedKeyValueHeap.getEstimatedSize(); } - return typedKeyValueHeap.getEstimatedSize(); + return estimatedSize; } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateFactory.java index bbf7d57684fd5..5cd21282cb640 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateFactory.java @@ -16,6 +16,7 @@ import com.facebook.presto.array.ObjectBigArray; import com.facebook.presto.operator.aggregation.TypedHeap; import com.facebook.presto.spi.function.AccumulatorStateFactory; +import org.openjdk.jol.info.ClassLayout; public class MinMaxNStateFactory implements AccumulatorStateFactory @@ -48,6 +49,7 @@ public static class GroupedMinMaxNState extends AbstractGroupedAccumulatorState implements MinMaxNState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedMinMaxNState.class).instanceSize(); private final ObjectBigArray heaps = new ObjectBigArray<>(); private long size; @@ -60,7 +62,7 @@ public void ensureCapacity(long size) @Override public long getEstimatedSize() { - return heaps.sizeOf() + size; + return INSTANCE_SIZE + heaps.sizeOf() + size; } @Override @@ -90,15 +92,17 @@ public void addMemoryUsage(long memory) public static class SingleMinMaxNState implements MinMaxNState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMinMaxNState.class).instanceSize(); private TypedHeap typedHeap; @Override public long getEstimatedSize() { - if (typedHeap == null) { - return 0; + long estimatedSize = INSTANCE_SIZE; + if (typedHeap != null) { + estimatedSize += typedHeap.getEstimatedSize(); } - return typedHeap.getEstimatedSize(); + return estimatedSize; } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java index 1c8068b681715..3115fd6806d2b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MinMaxNStateSerializer.java @@ -18,9 +18,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.util.Optional; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java index d523333f2e8f9..a10d05e22079b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairStateSerializer.java @@ -17,9 +17,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.util.Optional; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairsStateFactory.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairsStateFactory.java index fc1bae6835b69..559ff72eb0d63 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairsStateFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/MultiKeyValuePairsStateFactory.java @@ -17,6 +17,7 @@ import com.facebook.presto.operator.aggregation.MultiKeyValuePairs; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.type.Type; +import org.openjdk.jol.info.ClassLayout; import static java.util.Objects.requireNonNull; @@ -60,6 +61,7 @@ public static class GroupedState extends AbstractGroupedAccumulatorState implements MultiKeyValuePairsState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedState.class).instanceSize(); private final Type keyType; private final Type valueType; private final ObjectBigArray pairs = new ObjectBigArray<>(); @@ -118,13 +120,14 @@ public Type getValueType() @Override public long getEstimatedSize() { - return size + pairs.sizeOf(); + return INSTANCE_SIZE + size + pairs.sizeOf(); } } public static class SingleState implements MultiKeyValuePairsState { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleState.class).instanceSize(); private final Type keyType; private final Type valueType; private MultiKeyValuePairs pair; @@ -167,10 +170,11 @@ public Type getValueType() @Override public long getEstimatedSize() { - if (pair == null) { - return 0; + long estimatedSize = INSTANCE_SIZE; + if (pair != null) { + estimatedSize += pair.estimatedInMemorySize(); } - return pair.estimatedInMemorySize(); + return estimatedSize; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java index 94c722b7f678a..8c0efe6db7604 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/state/StateCompiler.java @@ -35,10 +35,10 @@ import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.AccumulatorStateSerializer; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; -import com.facebook.presto.type.RowType; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -72,11 +72,11 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantClass; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantLong; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNumber; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.defaultValue; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -403,21 +403,12 @@ private static Class generateSingleStateClass(Class clazz, M type(Object.class), type(clazz)); - // Store class size in static field - FieldDefinition classSize = definition.declareField(a(PRIVATE, STATIC, FINAL), "CLASS_SIZE", long.class); - definition.getClassInitializer() - .getBody() - .comment("CLASS_SIZE = ClassLayout.parseClass(%s.class).instanceSize()", definition.getName()) - .push(definition.getType()) - .invokeStatic(ClassLayout.class, "parseClass", ClassLayout.class, Class.class) - .invokeVirtual(ClassLayout.class, "instanceSize", int.class) - .intToLong() - .putStaticField(classSize); + FieldDefinition instanceSize = generateInstanceSize(definition); // Add getter for class size definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class)) .getBody() - .getStaticField(classSize) + .getStaticField(instanceSize) .retLong(); // Generate constructor @@ -439,6 +430,21 @@ private static Class generateSingleStateClass(Class clazz, M return defineClass(definition, clazz, classLoader); } + private static FieldDefinition generateInstanceSize(ClassDefinition definition) + { + // Store instance size in static field + FieldDefinition instanceSize = definition.declareField(a(PRIVATE, STATIC, FINAL), "INSTANCE_SIZE", long.class); + definition.getClassInitializer() + .getBody() + .comment("INSTANCE_SIZE = ClassLayout.parseClass(%s.class).instanceSize()", definition.getName()) + .push(definition.getType()) + .invokeStatic(ClassLayout.class, "parseClass", ClassLayout.class, Class.class) + .invokeVirtual(ClassLayout.class, "instanceSize", int.class) + .intToLong() + .putStaticField(instanceSize); + return instanceSize; + } + private static Class generateGroupedStateClass(Class clazz, Map fieldTypes, DynamicClassLoader classLoader) { ClassDefinition definition = new ClassDefinition( @@ -448,6 +454,8 @@ private static Class generateGroupedStateClass(Class clazz, type(clazz), type(GroupedAccumulator.class)); + FieldDefinition instanceSize = generateInstanceSize(definition); + List fields = enumerateFields(clazz, fieldTypes); // Create constructor @@ -474,8 +482,8 @@ private static Class generateGroupedStateClass(Class clazz, Variable size = getEstimatedSize.getScope().declareVariable(long.class, "size"); - // initialize size to 0L - body.append(size.set(constantLong(0))); + // initialize size to the size of the instance + body.append(size.set(getStatic(instanceSize))); // add field to size for (FieldDefinition field : fieldDefinitions) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java index 2844a942ab512..5bb1d25208c67 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageFilter.java @@ -64,10 +64,12 @@ public SelectedPositions filter(ConnectorSession session, Page page) if (block instanceof RunLengthEncodedBlock) { Block value = ((RunLengthEncodedBlock) block).getValue(); - Optional selectedDictionaryPositions = processDictionary(session, value); - // single value block is always considered effective - verify(selectedDictionaryPositions.isPresent()); - return SelectedPositions.positionsRange(0, selectedDictionaryPositions.get()[0] ? page.getPositionCount() : 0); + Optional selectedPosition = processDictionary(session, value); + // single value block is always considered effective, but the processing could have thrown + // in that case we fallback and process again so the correct error message sent + if (selectedPosition.isPresent()) { + return SelectedPositions.positionsRange(0, selectedPosition.get()[0] ? page.getPositionCount() : 0); + } } if (block instanceof DictionaryBlock) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java index be740bd85d1a3..c9d39bcd6bcbb 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java @@ -75,9 +75,11 @@ public Block project(ConnectorSession session, Page page, SelectedPositions sele if (block instanceof RunLengthEncodedBlock) { Block value = ((RunLengthEncodedBlock) block).getValue(); Optional projectedValue = processDictionary(session, value); - // single value block is always considered effective - verify(projectedValue.isPresent()); - return new RunLengthEncodedBlock(projectedValue.get(), selectedPositions.size()); + // single value block is always considered effective, but the processing could have thrown + // in that case we fallback and process again so the correct error message sent + if (projectedValue.isPresent()) { + return new RunLengthEncodedBlock(projectedValue.get(), selectedPositions.size()); + } } if (block instanceof DictionaryBlock) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedCursorProcessor.java b/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedCursorProcessor.java index 20ec27e6467aa..d6605b06d2d3a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedCursorProcessor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedCursorProcessor.java @@ -25,7 +25,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolToInputParameterRewriter; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ImmutableMap; import javax.annotation.Nullable; @@ -87,7 +87,7 @@ private static ExpressionInterpreter getExpressionInterpreter( parameterTypes.put(parameter, type); } - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList()); + Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList()); return expressionInterpreter(rewritten, metadata, session, expressionTypes); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageFilter.java b/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageFilter.java index 0eddf1c7bd101..2df00d66ba7d0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageFilter.java @@ -24,7 +24,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolToInputParameterRewriter; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ImmutableMap; import javax.annotation.concurrent.NotThreadSafe; @@ -65,8 +65,7 @@ public InterpretedPageFilter( Type type = inputTypes.get(parameter); parameterTypes.put(parameter, type); } - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList()); - + Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList()); this.evaluator = ExpressionInterpreter.expressionInterpreter(rewritten, metadata, session, expressionTypes); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageProjection.java b/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageProjection.java index c756375fe0937..a6c55e8555bbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageProjection.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/InterpretedPageProjection.java @@ -27,7 +27,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolToInputParameterRewriter; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -65,8 +65,7 @@ public InterpretedPageProjection( Type type = inputTypes.get(parameter); parameterTypes.put(parameter, type); } - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList()); - + Map, Type> expressionTypes = getExpressionTypesFromInput(session, metadata, sqlParser, parameterTypes.build(), rewritten, emptyList()); this.evaluator = ExpressionInterpreter.expressionInterpreter(rewritten, metadata, session, expressionTypes); blockBuilder = evaluator.getType().createBlockBuilder(new BlockBuilderStatus(), 1); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/project/PageFieldsToInputParametersRewriter.java b/presto-main/src/main/java/com/facebook/presto/operator/project/PageFieldsToInputParametersRewriter.java index 9d109eff657d7..8b7da048b841a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/project/PageFieldsToInputParametersRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/project/PageFieldsToInputParametersRewriter.java @@ -47,7 +47,7 @@ public static Result rewritePageFieldsToInputParameters(RowExpression expression } private static class Visitor - implements RowExpressionVisitor + implements RowExpressionVisitor { private final Map fieldToParameter = new HashMap<>(); private final List inputChannels = new ArrayList<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ApplyFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ApplyFunction.java index 711f48ce67198..2dc7e72c3d547 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ApplyFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ApplyFunction.java @@ -20,10 +20,12 @@ import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.gen.lambda.UnaryFunctionInterface; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; @@ -38,7 +40,7 @@ public final class ApplyFunction { public static final ApplyFunction APPLY_FUNCTION = new ApplyFunction(); - private static final MethodHandle METHOD_HANDLE = methodHandle(ApplyFunction.class, "apply", Object.class, MethodHandle.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(ApplyFunction.class, "apply", Object.class, UnaryFunctionInterface.class); private ApplyFunction() { @@ -78,6 +80,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in return new ScalarFunctionImplementation( true, ImmutableList.of(true, false), + ImmutableList.of(false, false), + ImmutableList.of(Optional.empty(), Optional.of(UnaryFunctionInterface.class)), METHOD_HANDLE.asType( METHOD_HANDLE.type() .changeReturnType(wrap(returnType.getJavaType())) @@ -85,10 +89,10 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in isDeterministic()); } - public static Object apply(Object input, MethodHandle function) + public static Object apply(Object input, UnaryFunctionInterface function) { try { - return function.invoke(input); + return function.apply(input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java index 3e9ca40e5d406..8746f39882049 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayConcatFunction.java @@ -98,6 +98,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in false, nCopies(arity, false), nCopies(arity, false), + nCopies(arity, Optional.empty()), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()), isDeterministic()); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java index 0e6db85df20a0..bce7bb8347c84 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayEqualOperator.java @@ -27,8 +27,8 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; @ScalarOperator(EQUAL) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java index 4ef010e4fdb02..8c44e317ed299 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFilterFunction.java @@ -22,11 +22,10 @@ import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.function.TypeParameterSpecialization; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.gen.lambda.LambdaFunctionInterface; import com.google.common.base.Throwables; import io.airlift.slice.Slice; -import java.lang.invoke.MethodHandle; - import static java.lang.Boolean.TRUE; @Description("return array containing elements that match the given predicate") @@ -38,9 +37,10 @@ private ArrayFilterFunction() {} @TypeParameter("T") @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) @SqlType("array(T)") - public static Block filterLong(@TypeParameter("T") Type elementType, + public static Block filterLong( + @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") MethodHandle function) + @SqlType("function(T, boolean)") FilterLongLambda function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); @@ -52,7 +52,7 @@ public static Block filterLong(@TypeParameter("T") Type elementType, Boolean keep; try { - keep = (Boolean) function.invokeExact(input); + keep = function.apply(input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); @@ -67,9 +67,10 @@ public static Block filterLong(@TypeParameter("T") Type elementType, @TypeParameter("T") @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) @SqlType("array(T)") - public static Block filterDouble(@TypeParameter("T") Type elementType, + public static Block filterDouble( + @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") MethodHandle function) + @SqlType("function(T, boolean)") FilterDoubleLambda function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); @@ -81,7 +82,7 @@ public static Block filterDouble(@TypeParameter("T") Type elementType, Boolean keep; try { - keep = (Boolean) function.invokeExact(input); + keep = function.apply(input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); @@ -96,9 +97,10 @@ public static Block filterDouble(@TypeParameter("T") Type elementType, @TypeParameter("T") @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) @SqlType("array(T)") - public static Block filterBoolean(@TypeParameter("T") Type elementType, + public static Block filterBoolean( + @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") MethodHandle function) + @SqlType("function(T, boolean)") FilterBooleanLambda function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); @@ -110,7 +112,7 @@ public static Block filterBoolean(@TypeParameter("T") Type elementType, Boolean keep; try { - keep = (Boolean) function.invokeExact(input); + keep = function.apply(input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); @@ -125,9 +127,10 @@ public static Block filterBoolean(@TypeParameter("T") Type elementType, @TypeParameter("T") @TypeParameterSpecialization(name = "T", nativeContainerType = Slice.class) @SqlType("array(T)") - public static Block filterSlice(@TypeParameter("T") Type elementType, + public static Block filterSlice( + @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") MethodHandle function) + @SqlType("function(T, boolean)") FilterSliceLambda function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); @@ -139,7 +142,7 @@ public static Block filterSlice(@TypeParameter("T") Type elementType, Boolean keep; try { - keep = (Boolean) function.invokeExact(input); + keep = function.apply(input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); @@ -154,9 +157,10 @@ public static Block filterSlice(@TypeParameter("T") Type elementType, @TypeParameter("T") @TypeParameterSpecialization(name = "T", nativeContainerType = Block.class) @SqlType("array(T)") - public static Block filterBlock(@TypeParameter("T") Type elementType, + public static Block filterBlock( + @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") MethodHandle function) + @SqlType("function(T, boolean)") FilterBlockLambda function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); @@ -168,7 +172,7 @@ public static Block filterBlock(@TypeParameter("T") Type elementType, Boolean keep; try { - keep = (Boolean) function.invokeExact(input); + keep = function.apply(input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); @@ -183,16 +187,17 @@ public static Block filterBlock(@TypeParameter("T") Type elementType, @TypeParameter("T") @TypeParameterSpecialization(name = "T", nativeContainerType = void.class) @SqlType("array(T)") - public static Block filterVoid(@TypeParameter("T") Type elementType, + public static Block filterVoid( + @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") MethodHandle function) + @SqlType("function(T, boolean)") FilterVoidLambda function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), positionCount); for (int position = 0; position < positionCount; position++) { Boolean keep; try { - keep = (Boolean) function.invokeExact(null); + keep = function.apply(null); } catch (Throwable throwable) { throw Throwables.propagate(throwable); @@ -203,4 +208,40 @@ public static Block filterVoid(@TypeParameter("T") Type elementType, } return resultBuilder.build(); } + + @FunctionalInterface + public interface FilterLongLambda extends LambdaFunctionInterface + { + Boolean apply(Long x); + } + + @FunctionalInterface + public interface FilterDoubleLambda extends LambdaFunctionInterface + { + Boolean apply(Double x); + } + + @FunctionalInterface + public interface FilterBooleanLambda extends LambdaFunctionInterface + { + Boolean apply(Boolean x); + } + + @FunctionalInterface + public interface FilterSliceLambda extends LambdaFunctionInterface + { + Boolean apply(Slice x); + } + + @FunctionalInterface + public interface FilterBlockLambda extends LambdaFunctionInterface + { + Boolean apply(Block x); + } + + @FunctionalInterface + public interface FilterVoidLambda extends LambdaFunctionInterface + { + Boolean apply(Void x); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java index b25fc849d2095..06ad23192ff61 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFlattenFunction.java @@ -32,6 +32,7 @@ import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.util.Reflection.methodHandle; +import static java.lang.Math.toIntExact; public class ArrayFlattenFunction extends SqlScalarFunction @@ -84,7 +85,7 @@ public static Block flatten(Type type, Type arrayType, Block array) return type.createBlockBuilder(new BlockBuilderStatus(), 0).build(); } - BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), array.getPositionCount(), array.getSizeInBytes() / array.getPositionCount()); + BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), array.getPositionCount(), toIntExact(array.getSizeInBytes() / array.getPositionCount())); for (int i = 0; i < array.getPositionCount(); i++) { if (!array.isNull(i)) { Block subArray = (Block) arrayType.getObject(array, i); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java index 71465038997b0..2ecb5903c320d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFunctions.java @@ -18,7 +18,7 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import static com.facebook.presto.type.UnknownType.UNKNOWN; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java index 9fb1ab3bab83c..ac9670db01bd1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOperator.java @@ -28,8 +28,8 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; @ScalarOperator(GREATER_THAN) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java index 0cb9b023e8878..c5fdfa320a9a5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayGreaterThanOrEqualOperator.java @@ -28,8 +28,8 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN; import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; @ScalarOperator(GREATER_THAN_OR_EQUAL) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java index 01bd53a45c1cf..aef060e9cca9b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java @@ -45,6 +45,7 @@ import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.Reflection.methodHandle; +import static java.lang.Math.toIntExact; import static java.lang.String.format; public final class ArrayJoin @@ -202,7 +203,7 @@ public static Slice arrayJoin(MethodHandle castFunction, ConnectorSession sessio { int numElements = arrayBlock.getPositionCount(); - DynamicSliceOutput sliceOutput = new DynamicSliceOutput(arrayBlock.getSizeInBytes() + delimiter.length() * arrayBlock.getPositionCount()); + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(toIntExact(arrayBlock.getSizeInBytes() + delimiter.length() * arrayBlock.getPositionCount())); for (int i = 0; i < numElements; i++) { if (arrayBlock.isNull(i)) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java index b966d16121a7b..f245a0b436596 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOperator.java @@ -28,8 +28,8 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; @ScalarOperator(LESS_THAN) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java index 104f1b3cb80bf..6be0c164710f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayLessThanOrEqualOperator.java @@ -29,8 +29,8 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; @ScalarOperator(LESS_THAN_OR_EQUAL) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReduceFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReduceFunction.java index 8bf5b1b6128fe..f80e4db63589f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReduceFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayReduceFunction.java @@ -21,11 +21,14 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; +import com.facebook.presto.sql.gen.lambda.UnaryFunctionInterface; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; @@ -37,7 +40,7 @@ public final class ArrayReduceFunction { public static final ArrayReduceFunction ARRAY_REDUCE_FUNCTION = new ArrayReduceFunction(); - private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayReduceFunction.class, "reduce", Type.class, Block.class, Object.class, MethodHandle.class, MethodHandle.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayReduceFunction.class, "reduce", Type.class, Block.class, Object.class, BinaryFunctionInterface.class, UnaryFunctionInterface.class); private ArrayReduceFunction() { @@ -79,6 +82,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in return new ScalarFunctionImplementation( true, ImmutableList.of(false, true, false, false), + ImmutableList.of(false, false, false, false), + ImmutableList.of(Optional.empty(), Optional.empty(), Optional.of(BinaryFunctionInterface.class), Optional.of(UnaryFunctionInterface.class)), methodHandle.asType( methodHandle.type() .changeParameterType(1, Primitives.wrap(intermediateType.getJavaType())) @@ -90,22 +95,22 @@ public static Object reduce( Type inputType, Block block, Object initialIntermediateValue, - MethodHandle inputFunction, - MethodHandle outputFunction) + BinaryFunctionInterface inputFunction, + UnaryFunctionInterface outputFunction) { int positionCount = block.getPositionCount(); Object intermediateValue = initialIntermediateValue; for (int position = 0; position < positionCount; position++) { Object input = readNativeValue(inputType, block, position); try { - intermediateValue = inputFunction.invoke(intermediateValue, input); + intermediateValue = inputFunction.apply(intermediateValue, input); } catch (Throwable throwable) { throw Throwables.propagate(throwable); } } try { - return outputFunction.invoke(intermediateValue); + return outputFunction.apply(intermediateValue); } catch (Throwable throwable) { throw Throwables.propagate(throwable); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java index b6bb34fb64b2a..3af984b239ad3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayTransformFunction.java @@ -30,14 +30,14 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.CallSiteBinder; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.sql.gen.lambda.UnaryFunctionInterface; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; -import java.lang.invoke.MethodHandle; import java.util.List; import java.util.Optional; @@ -108,7 +108,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in false, ImmutableList.of(false, false), ImmutableList.of(false, false), - methodHandle(generatedClass, "transform", PageBuilder.class, Block.class, MethodHandle.class), + ImmutableList.of(Optional.empty(), Optional.of(UnaryFunctionInterface.class)), + methodHandle(generatedClass, "transform", PageBuilder.class, Block.class, UnaryFunctionInterface.class), Optional.of(methodHandle(generatedClass, "createPageBuilder")), isDeterministic()); } @@ -133,7 +134,7 @@ private static Class generateTransform(Type inputType, Type outputType) // define transform method Parameter pageBuilder = arg("pageBuilder", PageBuilder.class); Parameter block = arg("block", Block.class); - Parameter function = arg("function", MethodHandle.class); + Parameter function = arg("function", UnaryFunctionInterface.class); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), @@ -188,7 +189,7 @@ private static Class generateTransform(Type inputType, Type outputType) .update(incrementVariable(position, (byte) 1)) .body(new BytecodeBlock() .append(loadInputElement) - .append(outputElement.set(function.invoke("invokeExact", outputJavaType, inputElement))) + .append(outputElement.set(function.invoke("apply", Object.class, inputElement.cast(Object.class)).cast(outputJavaType))) .append(writeOutputElement))); body.append(pageBuilder.invoke("declarePositions", void.class, positionCount)); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java index b3824f22da9e1..8cfea16118416 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java @@ -18,15 +18,24 @@ import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.LiteralParameter; +import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import io.airlift.slice.SliceUtf8; +import io.airlift.slice.Slices; + +import java.util.ArrayList; +import java.util.List; import static com.facebook.presto.spi.type.Chars.padSpaces; -import static com.facebook.presto.spi.type.Chars.trimSpaces; import static com.facebook.presto.spi.type.Chars.trimSpacesAndTruncateToLength; import static com.facebook.presto.spi.type.Varchars.truncateToLength; -import static io.airlift.slice.SliceUtf8.countCodePoints; -import static io.airlift.slice.SliceUtf8.offsetOfCodePoint; -import static io.airlift.slice.Slices.EMPTY_SLICE; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.SliceUtf8.getCodePointAt; +import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; +import static io.airlift.slice.SliceUtf8.setCodePointAt; +import static java.lang.Math.toIntExact; +import static java.util.Collections.nCopies; public final class CharacterStringCasts { @@ -81,24 +90,74 @@ public static Slice charToVarcharCast(@LiteralParameter("x") Long x, @LiteralPar @ScalarOperator(OperatorType.SATURATED_FLOOR_CAST) @SqlType("char(y)") @LiteralParameters({"x", "y"}) - // This function returns Char(y) value that is smaller than original Varchar(x) value. However, it is not necessarily the largest - // Char(y) value that is smaller than the original Varchar(x) value. This is fine though for usage in TupleDomainTranslator. - public static Slice varcharToCharSaturatedFloorCast(@LiteralParameter("y") Long y, @SqlType("varchar(x)") Slice slice) + public static Slice varcharToCharSaturatedFloorCast(@LiteralParameter("y") long y, @SqlType("varchar(x)") Slice slice) { - Slice trimmedSlice = trimSpaces(slice); - int trimmedTextLength = countCodePoints(trimmedSlice); - int numberOfTrailingSpaces = slice.length() - trimmedSlice.length(); + List codePoints = new ArrayList<>(toCodePoints(slice)); // if Varchar(x) value length (including spaces) is greater than y, we can just truncate it - if (trimmedTextLength + numberOfTrailingSpaces >= y) { - return truncateToLength(trimmedSlice, y.intValue()); + if (codePoints.size() >= y) { + // char(y) slice representation doesn't contain trailing spaces + codePoints = trimTrailing(codePoints, ' '); + List codePointsTruncated = codePoints.stream() + .limit(y) + .collect(toImmutableList()); + return codePointsToSliceUtf8(codePointsTruncated); } - if (trimmedTextLength == 0) { - return EMPTY_SLICE; + + /* + * Value length is smaller than same-represented char(y) value because input varchar has length lower than y. + * We decrement last character in input (in fact, we decrement last non-zero character) and pad the value with + * max code point up to y characters. + */ + codePoints = trimTrailing(codePoints, '\0'); + + if (codePoints.isEmpty()) { + // No non-zero characters in input and input is shorter than y. Input value is smaller than any char(4) casted back to varchar, so we return the smallest char(4) possible + return codePointsToSliceUtf8(nCopies(toIntExact(y), (int) '\0')); + } + + codePoints = new ArrayList<>(codePoints); + codePoints.set(codePoints.size() - 1, codePoints.get(codePoints.size() - 1) - 1); + codePoints.addAll(nCopies(toIntExact(y) - codePoints.size(), Character.MAX_CODE_POINT)); + + verify(codePoints.get(codePoints.size() - 1) != ' '); // no trailing spaces to trim + + return codePointsToSliceUtf8(codePoints); + } + + private static List trimTrailing(List codePoints, int codePointToTrim) + { + int endIndex = codePoints.size(); + while (endIndex > 0 && codePoints.get(endIndex - 1) == codePointToTrim) { + endIndex--; + } + return ImmutableList.copyOf(codePoints.subList(0, endIndex)); + } + + private static List toCodePoints(Slice slice) + { + ImmutableList.Builder codePoints = ImmutableList.builder(); + for (int offset = 0; offset < slice.length(); ) { + int codePoint = getCodePointAt(slice, offset); + offset += lengthOfCodePoint(slice, offset); + codePoints.add(codePoint); + } + return codePoints.build(); + } + + private static Slice codePointsToSliceUtf8(List codePoints) + { + int length = codePoints.stream() + .mapToInt(SliceUtf8::lengthOfCodePoint) + .sum(); + + Slice result = Slices.wrappedBuffer(new byte[length]); + int offset = 0; + for (int codePoint : codePoints) { + setCodePointAt(codePoint, result, offset); + offset += lengthOfCodePoint(codePoint); } - // if Varchar(x) value length (including spaces) is smaller than y, we truncate all spaces - // and also remove one additional trailing character to get smaller Char(y) value - return trimmedSlice.slice(0, offsetOfCodePoint(trimmedSlice, trimmedTextLength - 1)); + return result; } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java index 0922d48208a09..7e5288458bba8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.type.StandardTypes; import io.airlift.concurrent.ThreadLocalCache; import io.airlift.slice.Slice; +import io.airlift.units.Duration; import org.joda.time.DateTime; import org.joda.time.DateTimeField; import org.joda.time.Days; @@ -1075,4 +1076,18 @@ else if (character == '%') { throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e); } } + + @Description("convert duration string to an interval") + @ScalarFunction("parse_duration") + @LiteralParameters("x") + @SqlType(StandardTypes.INTERVAL_DAY_TO_SECOND) + public static long parseDuration(@SqlType("varchar(x)") Slice duration) + { + try { + return Duration.valueOf(duration.toStringUtf8()).toMillis(); + } + catch (IllegalArgumentException e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java index 54c4ddc3bc8c4..226dbb9099966 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/EmptyMapConstructor.java @@ -14,26 +14,32 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.google.common.collect.ImmutableList; - -import static com.facebook.presto.type.UnknownType.UNKNOWN; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.Type; public final class EmptyMapConstructor { - private static final Block EMPTY_MAP = new InterleavedBlockBuilder(ImmutableList.of(UNKNOWN, UNKNOWN), new BlockBuilderStatus(), 0).build(); + private final Block emptyMap; - private EmptyMapConstructor() {} + public EmptyMapConstructor(@TypeParameter("map(unknown,unknown)") Type mapType) + { + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + mapBlockBuilder.beginBlockEntry(); + mapBlockBuilder.closeEntry(); + emptyMap = ((MapType) mapType).getObject(mapBlockBuilder.build(), 0); + } @Description("Creates an empty map") @ScalarFunction @SqlType("map(unknown,unknown)") - public static Block map() + public Block map() { - return EMPTY_MAP; + return emptyMap; } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/GroupingOperationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/GroupingOperationFunction.java new file mode 100644 index 0000000000000..9550e3caec597 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/GroupingOperationFunction.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.StandardTypes; + +import java.util.List; + +import static com.facebook.presto.spi.type.StandardTypes.BIGINT; + +public final class GroupingOperationFunction +{ + public static final String BIGINT_GROUPING = "bigint_grouping"; + public static final String INTEGER_GROUPING = "integer_grouping"; + public static final int MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT = 63; + public static final int MAX_NUMBER_GROUPING_ARGUMENTS_INTEGER = 31; + + private GroupingOperationFunction() {} + + /** + * The grouping function is used in conjunction with GROUPING SETS, ROLLUP and CUBE to + * indicate which columns are present in that grouping. + * + *

The grouping function must be invoked with arguments that exactly match the columns + * referenced in the corresponding GROUPING SET, ROLLUP or CUBE clause at the associated + * query level. Those column arguments are not evaluated and instead the function is + * re-written with the arguments below. + * + *

To compute the resulting bit set for a particular row, bits are assigned to the + * argument columns with the rightmost column being the most significant bit. For a + * given grouping, a bit is set to 0 if the corresponding column is included in the + * grouping and 1 otherwise. For an example, see the SQL documentation for the + * function. + * + * @param groupId An ordinal indicating which grouping is currently being processed. + * Each grouping is assigned a unique monotonically increasing integer. + * @param columns The column arguments with which the function was + * invoked converted to ordinals with respect to the base table column + * ordering. + * @param groupingSetDescriptors A collection of ordinal lists where the index of + * the list is the groupId and the list itself contains the ordinals of the + * columns present in the grouping. For example: [[0, 2], [2], [0, 1, 2]] + * means the the 0th list contains the set of columns that are present in + * the 0th grouping. + * @return A bit set converted to decimal indicating which columns are present in + * the grouping. If a column is NOT present in the grouping its corresponding + * bit is set to 1 and to 0 if the column is present in the grouping. + */ + @ScalarFunction(value = INTEGER_GROUPING, deterministic = false) + @SqlType(StandardTypes.INTEGER) + public static long integerGrouping( + @SqlType(BIGINT) long groupId, + @SqlType("ListLiteral") List columns, + @SqlType("ListLiteral") List> groupingSetDescriptors) + { + return calculateGrouping(groupId, columns, groupingSetDescriptors); + } + + @ScalarFunction(value = BIGINT_GROUPING, deterministic = false) + @SqlType(StandardTypes.BIGINT) + public static long bigintGrouping( + @SqlType(BIGINT) long groupId, + @SqlType("ListLiteral") List columns, + @SqlType("ListLiteral") List> groupingSetDescriptors) + { + return calculateGrouping(groupId, columns, groupingSetDescriptors); + } + + private static long calculateGrouping(long groupId, List columns, List> groupingSetDescriptors) + { + long grouping = (1L << columns.size()) - 1; + + List groupingSet = groupingSetDescriptors.get((int) groupId); + for (Integer groupingColumn : groupingSet) { + int index = columns.indexOf(groupingColumn); + if (index != -1) { + // Leftmost argument to grouping() (i.e. when index = 0) corresponds to + // the most significant bit in the result. That is why we shift 1L starting + // from the columns.size() - 1 bit index. + grouping = grouping & ~(1L << (columns.size() - 1 - index)); + } + } + + return grouping; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/InvokeFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/InvokeFunction.java index 9472ff3d9613b..cb3c4573ce8a2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/InvokeFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/InvokeFunction.java @@ -21,10 +21,12 @@ import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.gen.lambda.LambdaFunctionInterface; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; @@ -39,7 +41,7 @@ public final class InvokeFunction { public static final InvokeFunction INVOKE_FUNCTION = new InvokeFunction(); - private static final MethodHandle METHOD_HANDLE = methodHandle(InvokeFunction.class, "invoke", MethodHandle.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(InvokeFunction.class, "invoke", InvokeLambda.class); private InvokeFunction() { @@ -78,19 +80,27 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in return new ScalarFunctionImplementation( true, ImmutableList.of(false), + ImmutableList.of(false), + ImmutableList.of(Optional.of(InvokeLambda.class)), METHOD_HANDLE.asType( METHOD_HANDLE.type() .changeReturnType(wrap(returnType.getJavaType()))), isDeterministic()); } - public static Object invoke(MethodHandle function) + public static Object invoke(InvokeLambda function) { try { - return function.invoke(); + return function.apply(); } catch (Throwable throwable) { throw Throwables.propagate(throwable); } } + + @FunctionalInterface + public interface InvokeLambda extends LambdaFunctionInterface + { + Object apply(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonOperators.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonOperators.java index 746adb2d0d2a9..c5c8d6fa9b838 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonOperators.java @@ -45,7 +45,10 @@ import static com.facebook.presto.spi.type.StandardTypes.DOUBLE; import static com.facebook.presto.spi.type.StandardTypes.INTEGER; import static com.facebook.presto.spi.type.StandardTypes.JSON; +import static com.facebook.presto.spi.type.StandardTypes.REAL; +import static com.facebook.presto.spi.type.StandardTypes.SMALLINT; import static com.facebook.presto.spi.type.StandardTypes.TIMESTAMP; +import static com.facebook.presto.spi.type.StandardTypes.TINYINT; import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; import static com.facebook.presto.util.DateTimeUtils.printDate; import static com.facebook.presto.util.DateTimeUtils.printTimestampWithoutTimeZone; @@ -53,6 +56,8 @@ import static com.facebook.presto.util.JsonUtil.createJsonGenerator; import static com.facebook.presto.util.JsonUtil.createJsonParser; import static com.fasterxml.jackson.core.JsonFactory.Feature.CANONICALIZE_FIELD_NAMES; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; import static java.lang.Math.toIntExact; import static java.lang.String.format; @@ -222,6 +227,46 @@ public static Double castToDouble(@SqlType(JSON) Slice json) } } + @ScalarOperator(CAST) + @SqlNullable + @SqlType(REAL) + public static Long castToReal(@SqlType(JSON) Slice json) + { + try (JsonParser parser = createJsonParser(JSON_FACTORY, json)) { + parser.nextToken(); + Long result; + switch (parser.getCurrentToken()) { + case VALUE_NULL: + result = null; + break; + case VALUE_STRING: + result = VarcharOperators.castToFloat(Slices.utf8Slice(parser.getText())); + break; + case VALUE_NUMBER_FLOAT: + result = (long) floatToRawIntBits(parser.getFloatValue()); + break; + case VALUE_NUMBER_INT: + // An alternative is calling getLongValue and then BigintOperators.castToReal. + // It doesn't work as well because it can result in overflow and underflow exceptions for large integral numbers. + result = (long) floatToRawIntBits(parser.getFloatValue()); + break; + case VALUE_TRUE: + result = BooleanOperators.castToReal(true); + break; + case VALUE_FALSE: + result = BooleanOperators.castToReal(false); + break; + default: + throw new PrestoException(INVALID_CAST_ARGUMENT, format("Cannot cast '%s' to %s", json.toStringUtf8(), REAL)); + } + checkCondition(parser.nextToken() == null, INVALID_CAST_ARGUMENT, "Cannot cast input json to REAL"); // check no trailing token + return result; + } + catch (IOException e) { + throw new PrestoException(INVALID_CAST_ARGUMENT, format("Cannot cast '%s' to %s", json.toStringUtf8(), REAL)); + } + } + @ScalarOperator(CAST) @SqlNullable @SqlType(BOOLEAN) @@ -278,13 +323,42 @@ public static Slice castFromVarchar(@SqlType("varchar(x)") Slice value) } } + @ScalarOperator(CAST) + @SqlType(JSON) + public static Slice castFromTinyInt(@SqlType(TINYINT) long value) + throws IOException + { + return internalCastFromLong(value, 4); + } + + @ScalarOperator(CAST) + @SqlType(JSON) + public static Slice castFromSmallInt(@SqlType(SMALLINT) long value) + throws IOException + { + return internalCastFromLong(value, 8); + } + @ScalarOperator(CAST) @SqlType(JSON) public static Slice castFromInteger(@SqlType(INTEGER) long value) throws IOException + { + return internalCastFromLong(value, 12); + } + + @ScalarOperator(CAST) + @SqlType(JSON) + public static Slice castFromBigint(@SqlType(BIGINT) long value) + throws IOException + { + return internalCastFromLong(value, 20); + } + + private static Slice internalCastFromLong(long value, int estimatedSize) { try { - SliceOutput output = new DynamicSliceOutput(20); + SliceOutput output = new DynamicSliceOutput(estimatedSize); try (JsonGenerator jsonGenerator = createJsonGenerator(JSON_FACTORY, output)) { jsonGenerator.writeNumber(value); } @@ -297,11 +371,11 @@ public static Slice castFromInteger(@SqlType(INTEGER) long value) @ScalarOperator(CAST) @SqlType(JSON) - public static Slice castFromBigint(@SqlType(BIGINT) long value) + public static Slice castFromDouble(@SqlType(DOUBLE) double value) throws IOException { try { - SliceOutput output = new DynamicSliceOutput(20); + SliceOutput output = new DynamicSliceOutput(32); try (JsonGenerator jsonGenerator = createJsonGenerator(JSON_FACTORY, output)) { jsonGenerator.writeNumber(value); } @@ -314,13 +388,13 @@ public static Slice castFromBigint(@SqlType(BIGINT) long value) @ScalarOperator(CAST) @SqlType(JSON) - public static Slice castFromDouble(@SqlType(DOUBLE) double value) + public static Slice castFromReal(@SqlType(REAL) long value) throws IOException { try { SliceOutput output = new DynamicSliceOutput(32); try (JsonGenerator jsonGenerator = createJsonGenerator(JSON_FACTORY, output)) { - jsonGenerator.writeNumber(value); + jsonGenerator.writeNumber(intBitsToFloat((int) value)); } return output.slice(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java index ebd161623c957..97a133ea68e5f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToArrayCast.java @@ -23,11 +23,11 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java index 89ab4680f25a1..aee7ff4e835b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/JsonToMapCast.java @@ -22,13 +22,12 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; @@ -50,7 +49,7 @@ public class JsonToMapCast extends SqlOperator { public static final JsonToMapCast JSON_TO_MAP = new JsonToMapCast(); - private static final MethodHandle METHOD_HANDLE = methodHandle(JsonToMapCast.class, "toMap", Type.class, ConnectorSession.class, Slice.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(JsonToMapCast.class, "toMap", MapType.class, ConnectorSession.class, Slice.class); private JsonToMapCast() { @@ -67,28 +66,30 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in checkArgument(arity == 1, "Expected arity to be 1"); Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - Type mapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); + MapType mapType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); checkCondition(canCastFromJson(mapType), INVALID_CAST_ARGUMENT, "Cannot cast JSON to %s", mapType); MethodHandle methodHandle = METHOD_HANDLE.bindTo(mapType); return new ScalarFunctionImplementation(true, ImmutableList.of(false), methodHandle, isDeterministic()); } @UsedByGeneratedCode - public static Block toMap(Type mapType, ConnectorSession connectorSession, Slice json) + public static Block toMap(MapType mapType, ConnectorSession connectorSession, Slice json) { try { Map map = (Map) stackRepresentationToObject(connectorSession, json, mapType); if (map == null) { return null; } - Type keyType = ((MapType) mapType).getKeyType(); - Type valueType = ((MapType) mapType).getValueType(); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), map.size() * 2); + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); for (Map.Entry entry : map.entrySet()) { appendToBlockBuilder(keyType, entry.getKey(), blockBuilder); appendToBlockBuilder(valueType, entry.getValue(), blockBuilder); } - return blockBuilder.build(); + mapBlockBuilder.closeEntry(); + return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } catch (RuntimeException e) { throw new PrestoException(INVALID_CAST_ARGUMENT, "Value cannot be cast to " + mapType, e); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java new file mode 100644 index 0000000000000..b34b86de18c22 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ListLiteralCast.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.type.ListLiteralType; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static com.facebook.presto.spi.type.IntegerType.INTEGER; + +public class ListLiteralCast +{ + private ListLiteralCast() + { + } + + @ScalarOperator(OperatorType.CAST) + @SqlType(ListLiteralType.NAME) + public static List castArrayToListLiteral(@SqlType("array(integer)") Block array) + { + ImmutableList.Builder listBuilder = ImmutableList.builder(); + for (int i = 0; i < array.getPositionCount(); i++) { + listBuilder.add(array.getInt(i, 0)); + } + + return listBuilder.build(); + } + + @ScalarOperator(OperatorType.CAST) + @SqlType(ListLiteralType.NAME) + public static List> castArrayOfArraysToListLiteral(@SqlType("array(array(integer))") Block arrayOfArrays) + { + ArrayType arrayType = new ArrayType(INTEGER); + ImmutableList.Builder> outerListBuilder = ImmutableList.builder(); + for (int i = 0; i < arrayOfArrays.getPositionCount(); i++) { + Block subArray = arrayType.getObject(arrayOfArrays, i); + outerListBuilder.add(castArrayToListLiteral(subArray)); + } + + return outerListBuilder.build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java index 13d35aded9356..dca56aa4c4241 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java @@ -24,9 +24,11 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.InterleavedBlock; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.VarArgsToArrayAdapterGenerator.MethodHandleAndConstructor; import com.google.common.collect.ImmutableList; @@ -49,8 +51,8 @@ public final class MapConcatFunction private static final String FUNCTION_NAME = "map_concat"; private static final String DESCRIPTION = "Concatenates given maps"; - private static final MethodHandle USER_STATE_FACTORY = methodHandle(MapConcatFunction.class, "createMapState", Type.class, Type.class); - private static final MethodHandle METHOD_HANDLE = methodHandle(MapConcatFunction.class, "mapConcat", Type.class, Type.class, Object.class, Block[].class); + private static final MethodHandle USER_STATE_FACTORY = methodHandle(MapConcatFunction.class, "createMapState", MapType.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(MapConcatFunction.class, "mapConcat", MapType.class, Object.class, Block[].class); private MapConcatFunction() { @@ -90,31 +92,35 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); + MapType mapType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); MethodHandleAndConstructor methodHandleAndConstructor = generateVarArgsToArrayAdapter( Block.class, Block.class, arity, - METHOD_HANDLE.bindTo(keyType).bindTo(valueType), - USER_STATE_FACTORY.bindTo(keyType).bindTo(valueType)); + METHOD_HANDLE.bindTo(mapType), + USER_STATE_FACTORY.bindTo(mapType)); return new ScalarFunctionImplementation( false, nCopies(arity, false), nCopies(arity, false), + nCopies(arity, Optional.empty()), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()), isDeterministic()); } @UsedByGeneratedCode - public static Object createMapState(Type keyType, Type valueType) + public static Object createMapState(MapType mapType) { - return new PageBuilder(ImmutableList.of(keyType, valueType)); + return new PageBuilder(ImmutableList.of(mapType)); } @UsedByGeneratedCode - public static Block mapConcat(Type keyType, Type valueType, Object state, Block[] maps) + public static Block mapConcat(MapType mapType, Object state, Block[] maps) { int entries = 0; int lastMapIndex = maps.length - 1; @@ -136,17 +142,19 @@ public static Block mapConcat(Type keyType, Type valueType, Object state, Block[ } // TODO: we should move TypedSet into user state as well + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); TypedSet typedSet = new TypedSet(keyType, entries / 2); - BlockBuilder keyBlockBuilder = pageBuilder.getBlockBuilder(0); - BlockBuilder valueBlockBuilder = pageBuilder.getBlockBuilder(1); + BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); // the last map Block map = maps[lastMapIndex]; int total = 0; for (int i = 0; i < map.getPositionCount(); i += 2) { typedSet.add(map, i); - keyType.appendTo(map, i, keyBlockBuilder); - valueType.appendTo(map, i + 1, valueBlockBuilder); + keyType.appendTo(map, i, blockBuilder); + valueType.appendTo(map, i + 1, blockBuilder); total++; } // the map between the last and the first @@ -155,8 +163,8 @@ public static Block mapConcat(Type keyType, Type valueType, Object state, Block[ for (int i = 0; i < map.getPositionCount(); i += 2) { if (!typedSet.contains(map, i)) { typedSet.add(map, i); - keyType.appendTo(map, i, keyBlockBuilder); - valueType.appendTo(map, i + 1, valueBlockBuilder); + keyType.appendTo(map, i, blockBuilder); + valueType.appendTo(map, i + 1, blockBuilder); total++; } } @@ -165,16 +173,14 @@ public static Block mapConcat(Type keyType, Type valueType, Object state, Block[ map = maps[firstMapIndex]; for (int i = 0; i < map.getPositionCount(); i += 2) { if (!typedSet.contains(map, i)) { - keyType.appendTo(map, i, keyBlockBuilder); - valueType.appendTo(map, i + 1, valueBlockBuilder); + keyType.appendTo(map, i, blockBuilder); + valueType.appendTo(map, i + 1, blockBuilder); total++; } } - pageBuilder.declarePositions(total); - Block[] blocks = new Block[2]; - blocks[0] = keyBlockBuilder.getRegion(keyBlockBuilder.getPositionCount() - total, total); - blocks[1] = valueBlockBuilder.getRegion(valueBlockBuilder.getPositionCount() - total, total); - return new InterleavedBlock(blocks); + mapBlockBuilder.closeEntry(); + pageBuilder.declarePosition(); + return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java index 577c2546e1adc..362618ae94b57 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConstructor.java @@ -19,25 +19,27 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.metadata.Signature.comparableTypeParameter; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.StandardTypes.MAP; import static com.facebook.presto.util.Failures.checkCondition; +import static com.facebook.presto.util.Reflection.constructorMethodHandle; import static com.facebook.presto.util.Reflection.methodHandle; public final class MapConstructor @@ -45,7 +47,7 @@ public final class MapConstructor { public static final MapConstructor MAP_CONSTRUCTOR = new MapConstructor(); - private static final MethodHandle METHOD_HANDLE = methodHandle(MapConstructor.class, "createMap", MapType.class, Block.class, Block.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(MapConstructor.class, "createMap", MapType.class, MethodHandle.class, MethodHandle.class, State.class, Block.class, Block.class); private static final String DESCRIPTION = "Constructs a map from the given key/value arrays"; public MapConstructor() @@ -85,14 +87,30 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type valueType = boundVariables.getTypeVariable("V"); Type mapType = typeManager.getParameterizedType(MAP, ImmutableList.of(TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); - return new ScalarFunctionImplementation(false, ImmutableList.of(false, false), METHOD_HANDLE.bindTo(mapType), isDeterministic()); + MethodHandle keyHashCode = functionRegistry.getScalarFunctionImplementation(functionRegistry.resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(keyType))).getMethodHandle(); + MethodHandle keyEqual = functionRegistry.getScalarFunctionImplementation(functionRegistry.resolveOperator(OperatorType.EQUAL, ImmutableList.of(keyType, keyType))).getMethodHandle(); + MethodHandle instanceFactory = constructorMethodHandle(State.class, MapType.class).bindTo(mapType); + + return new ScalarFunctionImplementation( + false, + ImmutableList.of(false, false), + ImmutableList.of(false, false), + ImmutableList.of(Optional.empty(), Optional.empty()), + METHOD_HANDLE.bindTo(mapType).bindTo(keyEqual).bindTo(keyHashCode), + Optional.of(instanceFactory), + isDeterministic()); } @UsedByGeneratedCode - public static Block createMap(MapType mapType, Block keyBlock, Block valueBlock) + public static Block createMap(MapType mapType, MethodHandle keyEqual, MethodHandle keyHashCode, State state, Block keyBlock, Block valueBlock) { - BlockBuilder blockBuilder = new InterleavedBlockBuilder(mapType.getTypeParameters(), new BlockBuilderStatus(), keyBlock.getPositionCount() * 2); + PageBuilder pageBuilder = state.getPageBuilder(); + if (pageBuilder.isFull()) { + pageBuilder.reset(); + } + BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); checkCondition(keyBlock.getPositionCount() == valueBlock.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "Key and value arrays must be the same length"); for (int i = 0; i < keyBlock.getPositionCount(); i++) { if (keyBlock.isNull(i)) { @@ -101,7 +119,24 @@ public static Block createMap(MapType mapType, Block keyBlock, Block valueBlock) mapType.getKeyType().appendTo(keyBlock, i, blockBuilder); mapType.getValueType().appendTo(valueBlock, i, blockBuilder); } + mapBlockBuilder.closeEntry(); + pageBuilder.declarePosition(); - return blockBuilder.build(); + return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); + } + + public static final class State + { + private final PageBuilder pageBuilder; + + public State(MapType mapType) + { + pageBuilder = new PageBuilder(ImmutableList.of(mapType)); + } + + public PageBuilder getPageBuilder() + { + return pageBuilder; + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java index 07ef7ff87a517..2af51b3dded50 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapFilterFunction.java @@ -31,16 +31,19 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; -import com.facebook.presto.type.MapType; +import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; import java.util.List; +import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; @@ -105,17 +108,21 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in { Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); + Type mapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); return new ScalarFunctionImplementation( false, ImmutableList.of(false, false), - generateFilter(keyType, valueType), + ImmutableList.of(false, false), + ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), + generateFilter(keyType, valueType, mapType), isDeterministic()); } - private static MethodHandle generateFilter(Type keyType, Type valueType) + private static MethodHandle generateFilter(Type keyType, Type valueType, Type mapType) { CallSiteBinder binder = new CallSiteBinder(); - MapType mapType = new MapType(keyType, valueType); Class keyJavaType = Primitives.wrap(keyType.getJavaType()); Class valueJavaType = Primitives.wrap(valueType.getJavaType()); @@ -126,7 +133,7 @@ private static MethodHandle generateFilter(Type keyType, Type valueType) definition.declareDefaultConstructor(a(PRIVATE)); Parameter block = arg("block", Block.class); - Parameter function = arg("function", MethodHandle.class); + Parameter function = arg("function", BinaryFunctionInterface.class); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), "filter", @@ -181,7 +188,7 @@ private static MethodHandle generateFilter(Type keyType, Type valueType) .body(new BytecodeBlock() .append(loadKeyElement) .append(loadValueElement) - .append(keep.set(function.invoke("invokeExact", Boolean.class, keyElement, valueElement))) + .append(keep.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(Boolean.class))) .append(new IfStatement("if (keep != null && keep) ...") .condition(and(notEqual(keep, constantNull(Boolean.class)), keep.cast(boolean.class))) .ifTrue(new BytecodeBlock() @@ -191,6 +198,6 @@ private static MethodHandle generateFilter(Type keyType, Type valueType) body.append(blockBuilder.invoke("build", Block.class).ret()); Class generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapFilterFunction.class.getClassLoader()); - return methodHandle(generatedClass, "filter", Block.class, MethodHandle.class); + return methodHandle(generatedClass, "filter", Block.class, BinaryFunctionInterface.class); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java index e8eeb3fce1f11..31a8ec215b2c5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapSubscriptOperator.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.SingleMapBlock; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; @@ -48,15 +49,16 @@ public class MapSubscriptOperator extends SqlOperator { - private static final MethodHandle METHOD_HANDLE_BOOLEAN = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, boolean.class); - private static final MethodHandle METHOD_HANDLE_LONG = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, long.class); - private static final MethodHandle METHOD_HANDLE_DOUBLE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, double.class); - private static final MethodHandle METHOD_HANDLE_SLICE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Slice.class); - private static final MethodHandle METHOD_HANDLE_OBJECT = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Object.class); + private static final MethodHandle METHOD_HANDLE_BOOLEAN = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, boolean.class); + private static final MethodHandle METHOD_HANDLE_LONG = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, long.class); + private static final MethodHandle METHOD_HANDLE_DOUBLE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, double.class); + private static final MethodHandle METHOD_HANDLE_SLICE = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Slice.class); + private static final MethodHandle METHOD_HANDLE_OBJECT = methodHandle(MapSubscriptOperator.class, "subscript", boolean.class, boolean.class, FunctionInvoker.class, MethodHandle.class, Type.class, Type.class, ConnectorSession.class, Block.class, Object.class); private final boolean legacyMissingKey; + private final boolean useNewMapBlock; - public MapSubscriptOperator(boolean legacyMissingKey) + public MapSubscriptOperator(boolean legacyMissingKey, boolean useNewMapBlock) { super(SUBSCRIPT, ImmutableList.of(typeVariable("K"), typeVariable("V")), @@ -64,6 +66,7 @@ public MapSubscriptOperator(boolean legacyMissingKey) parseTypeSignature("V"), ImmutableList.of(parseTypeSignature("map(K,V)"), parseTypeSignature("K"))); this.legacyMissingKey = legacyMissingKey; + this.useNewMapBlock = useNewMapBlock; } @Override @@ -90,7 +93,7 @@ else if (keyType.getJavaType() == Slice.class) { else { methodHandle = METHOD_HANDLE_OBJECT; } - methodHandle = MethodHandles.insertArguments(methodHandle, 0, legacyMissingKey); + methodHandle = MethodHandles.insertArguments(methodHandle, 0, legacyMissingKey, useNewMapBlock); FunctionInvoker functionInvoker = new FunctionInvoker(functionRegistry); methodHandle = methodHandle.bindTo(functionInvoker).bindTo(keyEqualsMethod).bindTo(keyType).bindTo(valueType); @@ -106,8 +109,20 @@ else if (keyType.getJavaType() == Slice.class) { } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, boolean key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, boolean key) { + if (map instanceof SingleMapBlock && useNewMapBlock) { + SingleMapBlock mapBlock = (SingleMapBlock) map; + int valuePosition = mapBlock.seekKeyExact(key); + if (valuePosition == -1) { + if (legacyMissingKey) { + return null; + } + throw throwMissingKeyException(keyType, functionInvoker, key, session); + } + return readNativeValue(valueType, mapBlock, valuePosition); + } + // TODO: assume that map is always instanceof SingleMapBlock once all map producing code is updated. for (int position = 0; position < map.getPositionCount(); position += 2) { try { if ((boolean) keyEqualsMethod.invokeExact(keyType.getBoolean(map, position), key)) { @@ -127,8 +142,20 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, long key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, long key) { + if (map instanceof SingleMapBlock && useNewMapBlock) { + SingleMapBlock mapBlock = (SingleMapBlock) map; + int valuePosition = mapBlock.seekKeyExact(key); + if (valuePosition == -1) { + if (legacyMissingKey) { + return null; + } + throw throwMissingKeyException(keyType, functionInvoker, key, session); + } + return readNativeValue(valueType, mapBlock, valuePosition); + } + // TODO: assume that map is always instanceof SingleMapBlock once all map producing code is updated. for (int position = 0; position < map.getPositionCount(); position += 2) { try { if ((boolean) keyEqualsMethod.invokeExact(keyType.getLong(map, position), key)) { @@ -148,8 +175,20 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, double key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, double key) { + if (map instanceof SingleMapBlock && useNewMapBlock) { + SingleMapBlock mapBlock = (SingleMapBlock) map; + int valuePosition = mapBlock.seekKeyExact(key); + if (valuePosition == -1) { + if (legacyMissingKey) { + return null; + } + throw throwMissingKeyException(keyType, functionInvoker, key, session); + } + return readNativeValue(valueType, mapBlock, valuePosition); + } + // TODO: assume that map is always instanceof SingleMapBlock once all map producing code is updated. for (int position = 0; position < map.getPositionCount(); position += 2) { try { if ((boolean) keyEqualsMethod.invokeExact(keyType.getDouble(map, position), key)) { @@ -169,8 +208,20 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Slice key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Slice key) { + if (map instanceof SingleMapBlock && useNewMapBlock) { + SingleMapBlock mapBlock = (SingleMapBlock) map; + int valuePosition = mapBlock.seekKeyExact(key); + if (valuePosition == -1) { + if (legacyMissingKey) { + return null; + } + throw throwMissingKeyException(keyType, functionInvoker, key, session); + } + return readNativeValue(valueType, mapBlock, valuePosition); + } + // TODO: assume that map is always instanceof SingleMapBlock once all map producing code is updated. for (int position = 0; position < map.getPositionCount(); position += 2) { try { if ((boolean) keyEqualsMethod.invokeExact(keyType.getSlice(map, position), key)) { @@ -190,8 +241,20 @@ public static Object subscript(boolean legacyMissingKey, FunctionInvoker functio } @UsedByGeneratedCode - public static Object subscript(boolean legacyMissingKey, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Object key) + public static Object subscript(boolean legacyMissingKey, boolean useNewMapBlock, FunctionInvoker functionInvoker, MethodHandle keyEqualsMethod, Type keyType, Type valueType, ConnectorSession session, Block map, Object key) { + if (map instanceof SingleMapBlock && useNewMapBlock) { + SingleMapBlock mapBlock = (SingleMapBlock) map; + int valuePosition = mapBlock.seekKeyExact((Block) key); + if (valuePosition == -1) { + if (legacyMissingKey) { + return null; + } + throw throwMissingKeyException(keyType, functionInvoker, key, session); + } + return readNativeValue(valueType, mapBlock, valuePosition); + } + // TODO: assume that map is always instanceof SingleMapBlock once all map producing code is updated. for (int position = 0; position < map.getPositionCount(); position += 2) { try { if ((boolean) keyEqualsMethod.invoke(keyType.getObject(map, position), key)) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToJsonCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToJsonCast.java index f2a12b9e6f93e..2b6d01444860c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToJsonCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToJsonCast.java @@ -23,7 +23,7 @@ import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.fasterxml.jackson.core.JsonGenerator; import com.google.common.collect.ImmutableList; import io.airlift.slice.DynamicSliceOutput; @@ -69,7 +69,9 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in checkArgument(arity == 1, "Expected arity to be 1"); Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V"); - Type mapType = new MapType(keyType, valueType); + Type mapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); checkCondition(canCastToJson(mapType), INVALID_CAST_ARGUMENT, "Cannot cast %s to JSON", mapType); ObjectKeyProvider provider = ObjectKeyProvider.createObjectKeyProvider(keyType); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java index 3aa6494dca608..e52e3cc8e0f11 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java @@ -20,7 +20,6 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; @@ -28,7 +27,6 @@ import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -56,6 +54,7 @@ public static Block toMap( @TypeParameter("FV") Type fromValueType, @TypeParameter("TK") Type toKeyType, @TypeParameter("TV") Type toValueType, + @TypeParameter("map(TK,TV)") Type toMapType, ConnectorSession session, @SqlType("map(FK,FV)") Block fromMap) { @@ -94,7 +93,9 @@ public static Block toMap( } } Block keyBlock = keyBlockBuilder.build(); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(toKeyType, toValueType), new BlockBuilderStatus(), fromMap.getPositionCount()); + + BlockBuilder mapBlockBuilder = toMapType.createBlockBuilder(new BlockBuilderStatus(), 1); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); for (int i = 0; i < fromMap.getPositionCount(); i += 2) { if (!typedSet.contains(keyBlock, i / 2)) { typedSet.add(keyBlock, i / 2); @@ -120,6 +121,8 @@ public static Block toMap( throw new PrestoException(StandardErrorCode.INVALID_CAST_ARGUMENT, "duplicate keys"); } } - return blockBuilder.build(); + + mapBlockBuilder.closeEntry(); + return (Block) toMapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java index 82da40e9a2965..d0035c1ac6878 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformKeyFunction.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.annotation.UsedByGeneratedCode; import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; @@ -30,21 +31,23 @@ import com.facebook.presto.operator.aggregation.TypedSet; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.CallSiteBinder; import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; -import com.facebook.presto.type.MapType; +import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; -import java.util.List; +import java.util.Optional; import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; @@ -66,6 +69,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.subtract; import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -78,6 +82,7 @@ public final class MapTransformKeyFunction extends SqlScalarFunction { public static final MapTransformKeyFunction MAP_TRANSFORM_KEY_FUNCTION = new MapTransformKeyFunction(); + private static final MethodHandle STATE_FACTORY = methodHandle(MapTransformKeyFunction.class, "createState", MapType.class); private MapTransformKeyFunction() { @@ -115,17 +120,28 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K1"); Type transformedKeyType = boundVariables.getTypeVariable("K2"); Type valueType = boundVariables.getTypeVariable("V"); + MapType resultMapType = (MapType) typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(transformedKeyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); return new ScalarFunctionImplementation( false, ImmutableList.of(false, false), - generateTransformKey(keyType, transformedKeyType, valueType), + ImmutableList.of(false, false), + ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), + generateTransformKey(keyType, transformedKeyType, valueType, resultMapType), + Optional.of(STATE_FACTORY.bindTo(resultMapType)), isDeterministic()); } - private static MethodHandle generateTransformKey(Type keyType, Type transformedKeyType, Type valueType) + @UsedByGeneratedCode + public static Object createState(MapType mapType) + { + return new PageBuilder(ImmutableList.of(mapType)); + } + + private static MethodHandle generateTransformKey(Type keyType, Type transformedKeyType, Type valueType, Type resultMapType) { CallSiteBinder binder = new CallSiteBinder(); - MapType mapType = new MapType(transformedKeyType, valueType); Class keyJavaType = Primitives.wrap(keyType.getJavaType()); Class transformedKeyJavaType = Primitives.wrap(transformedKeyType.getJavaType()); Class valueJavaType = Primitives.wrap(valueType.getJavaType()); @@ -136,19 +152,22 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK type(Object.class)); definition.declareDefaultConstructor(a(PRIVATE)); + Parameter state = arg("state", Object.class); Parameter session = arg("session", ConnectorSession.class); Parameter block = arg("block", Block.class); - Parameter function = arg("function", MethodHandle.class); + Parameter function = arg("function", BinaryFunctionInterface.class); MethodDefinition method = definition.declareMethod( a(PUBLIC, STATIC), "transform", type(Block.class), - ImmutableList.of(session, block, function)); + ImmutableList.of(state, session, block, function)); BytecodeBlock body = method.getBody(); Scope scope = method.getScope(); Variable positionCount = scope.declareVariable(int.class, "positionCount"); Variable position = scope.declareVariable(int.class, "position"); + Variable pageBuilder = scope.declareVariable(PageBuilder.class, "pageBuilder"); + Variable mapBlockBuilder = scope.declareVariable(BlockBuilder.class, "mapBlockBuilder"); Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder"); Variable typedSet = scope.declareVariable(TypedSet.class, "typeSet"); Variable keyElement = scope.declareVariable(keyJavaType, "keyElement"); @@ -158,12 +177,13 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK // invoke block.getPositionCount() body.append(positionCount.set(block.invoke("getPositionCount", int.class))); - // create the interleaved block builder - body.append(blockBuilder.set(newInstance( - InterleavedBlockBuilder.class, - constantType(binder, mapType).invoke("getTypeParameters", List.class), - newInstance(BlockBuilderStatus.class), - positionCount))); + // prepare the single map block builder + body.append(pageBuilder.set(state.cast(PageBuilder.class))); + body.append(new IfStatement() + .condition(pageBuilder.invoke("isFull", boolean.class)) + .ifTrue(pageBuilder.invoke("reset", void.class))); + body.append(mapBlockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0)))); + body.append(blockBuilder.set(mapBlockBuilder.invoke("beginBlockEntry", BlockBuilder.class))); // create typed set body.append(typedSet.set(newInstance( @@ -210,7 +230,7 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK BytecodeNode throwDuplicatedKeyException; if (!transformedKeyType.equals(UNKNOWN)) { writeKeyElement = new BytecodeBlock() - .append(transformedKeyElement.set(function.invoke("invokeExact", transformedKeyJavaType, keyElement, valueElement))) + .append(transformedKeyElement.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(transformedKeyJavaType))) .append(new IfStatement() .condition(equal(transformedKeyElement, constantNull(transformedKeyJavaType))) .ifTrue(throwNullKeyException) @@ -251,9 +271,18 @@ private static MethodHandle generateTransformKey(Type keyType, Type transformedK .ifTrue(throwDuplicatedKeyException) .ifFalse(typedSet.invoke("add", void.class, blockBuilder.cast(Block.class), position))))); - body.append(blockBuilder.invoke("build", Block.class).ret()); + body.append(mapBlockBuilder + .invoke("closeEntry", BlockBuilder.class) + .pop()); + body.append(constantType(binder, resultMapType) + .invoke( + "getObject", + Object.class, + mapBlockBuilder.cast(Block.class), + subtract(mapBlockBuilder.invoke("getPositionCount", int.class), constantInt(1))) + .ret()); Class generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapTransformKeyFunction.class.getClassLoader()); - return methodHandle(generatedClass, "transform", ConnectorSession.class, Block.class, MethodHandle.class); + return methodHandle(generatedClass, "transform", Object.class, ConnectorSession.class, Block.class, BinaryFunctionInterface.class); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java index feb5a77139f80..46aba75783efa 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapTransformValueFunction.java @@ -13,41 +13,71 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.annotation.UsedByGeneratedCode; +import com.facebook.presto.bytecode.BytecodeBlock; +import com.facebook.presto.bytecode.BytecodeNode; +import com.facebook.presto.bytecode.ClassDefinition; +import com.facebook.presto.bytecode.MethodDefinition; +import com.facebook.presto.bytecode.Parameter; +import com.facebook.presto.bytecode.Scope; +import com.facebook.presto.bytecode.Variable; +import com.facebook.presto.bytecode.control.ForLoop; +import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PageBuilder; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.google.common.base.Throwables; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.gen.CallSiteBinder; +import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression; +import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Primitives; import java.lang.invoke.MethodHandle; +import java.util.Optional; +import static com.facebook.presto.bytecode.Access.FINAL; +import static com.facebook.presto.bytecode.Access.PRIVATE; +import static com.facebook.presto.bytecode.Access.PUBLIC; +import static com.facebook.presto.bytecode.Access.STATIC; +import static com.facebook.presto.bytecode.Access.a; +import static com.facebook.presto.bytecode.CompilerUtils.defineClass; +import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; +import static com.facebook.presto.bytecode.Parameter.arg; +import static com.facebook.presto.bytecode.ParameterizedType.type; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.subtract; +import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable; import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; -import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue; +import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; +import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.Reflection.methodHandle; public final class MapTransformValueFunction extends SqlScalarFunction { public static final MapTransformValueFunction MAP_TRANSFORM_VALUE_FUNCTION = new MapTransformValueFunction(); - - private static final MethodHandle METHOD_HANDLE = methodHandle( - MapTransformValueFunction.class, - "transform", - Type.class, - Type.class, - Type.class, - Block.class, - MethodHandle.class); + private static final MethodHandle STATE_FACTORY = methodHandle(MapTransformKeyFunction.class, "createState", MapType.class); private MapTransformValueFunction() { @@ -85,31 +115,137 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type keyType = boundVariables.getTypeVariable("K"); Type valueType = boundVariables.getTypeVariable("V1"); Type transformedValueType = boundVariables.getTypeVariable("V2"); + Type resultMapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(transformedValueType.getTypeSignature()))); return new ScalarFunctionImplementation( false, ImmutableList.of(false, false), - METHOD_HANDLE.bindTo(keyType).bindTo(valueType).bindTo(transformedValueType), + ImmutableList.of(false, false), + ImmutableList.of(Optional.empty(), Optional.of(BinaryFunctionInterface.class)), + generateTransform(keyType, valueType, transformedValueType, resultMapType), + Optional.of(STATE_FACTORY.bindTo(resultMapType)), isDeterministic()); } - public static Block transform(Type keyType, Type valueType, Type transformedValueType, Block block, MethodHandle function) + @UsedByGeneratedCode + public static Object createState(MapType mapType) + { + return new PageBuilder(ImmutableList.of(mapType)); + } + + private static MethodHandle generateTransform(Type keyType, Type valueType, Type transformedValueType, Type resultMapType) { - int positionCount = block.getPositionCount(); - BlockBuilder resultBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, transformedValueType), new BlockBuilderStatus(), positionCount); - for (int position = 0; position < positionCount; position += 2) { - Object key = readNativeValue(keyType, block, position); - Object value = readNativeValue(valueType, block, position + 1); - Object transformedValue; - try { - transformedValue = function.invoke(key, value); - } - catch (Throwable throwable) { - throw Throwables.propagate(throwable); - } - - keyType.appendTo(block, position, resultBuilder); - writeNativeValue(transformedValueType, resultBuilder, transformedValue); + CallSiteBinder binder = new CallSiteBinder(); + Class keyJavaType = Primitives.wrap(keyType.getJavaType()); + Class valueJavaType = Primitives.wrap(valueType.getJavaType()); + Class transformedValueJavaType = Primitives.wrap(transformedValueType.getJavaType()); + + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("MapTransformValue"), + type(Object.class)); + definition.declareDefaultConstructor(a(PRIVATE)); + + // define transform method + Parameter state = arg("state", Object.class); + Parameter block = arg("block", Block.class); + Parameter function = arg("function", BinaryFunctionInterface.class); + MethodDefinition method = definition.declareMethod( + a(PUBLIC, STATIC), + "transform", + type(Block.class), + ImmutableList.of(state, block, function)); + + BytecodeBlock body = method.getBody(); + Scope scope = method.getScope(); + Variable positionCount = scope.declareVariable(int.class, "positionCount"); + Variable position = scope.declareVariable(int.class, "position"); + Variable pageBuilder = scope.declareVariable(PageBuilder.class, "pageBuilder"); + Variable mapBlockBuilder = scope.declareVariable(BlockBuilder.class, "mapBlockBuilder"); + Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder"); + Variable keyElement = scope.declareVariable(keyJavaType, "keyElement"); + Variable valueElement = scope.declareVariable(valueJavaType, "valueElement"); + Variable transformedValueElement = scope.declareVariable(transformedValueJavaType, "transformedValueElement"); + + // invoke block.getPositionCount() + body.append(positionCount.set(block.invoke("getPositionCount", int.class))); + + // prepare the single map block builder + body.append(pageBuilder.set(state.cast(PageBuilder.class))); + body.append(new IfStatement() + .condition(pageBuilder.invoke("isFull", boolean.class)) + .ifTrue(pageBuilder.invoke("reset", void.class))); + body.append(mapBlockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0)))); + body.append(blockBuilder.set(mapBlockBuilder.invoke("beginBlockEntry", BlockBuilder.class))); + + // throw null key exception block + BytecodeNode throwNullKeyException = new BytecodeBlock() + .append(newInstance( + PrestoException.class, + getStatic(INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class), + constantString("map key cannot be null"))) + .throwObject(); + + SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType); + BytecodeNode loadKeyElement; + if (!keyType.equals(UNKNOWN)) { + loadKeyElement = new BytecodeBlock().append(keyElement.set(keySqlType.getValue(block, position).cast(keyJavaType))); + } + else { + // make sure invokeExact will not take uninitialized keys during compile time + // but if we reach this point during runtime, it is an exception + loadKeyElement = new BytecodeBlock() + .append(keyElement.set(constantNull(keyJavaType))) + .append(throwNullKeyException); + } + + SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType); + BytecodeNode loadValueElement; + if (!valueType.equals(UNKNOWN)) { + loadValueElement = new IfStatement() + .condition(block.invoke("isNull", boolean.class, add(position, constantInt(1)))) + .ifTrue(valueElement.set(constantNull(valueJavaType))) + .ifFalse(valueElement.set(valueSqlType.getValue(block, add(position, constantInt(1))).cast(valueJavaType))); + } + else { + loadValueElement = new BytecodeBlock().append(valueElement.set(constantNull(valueJavaType))); } - return resultBuilder.build(); + + BytecodeNode writeTransformedValueElement; + if (!transformedValueType.equals(UNKNOWN)) { + writeTransformedValueElement = new IfStatement() + .condition(equal(transformedValueElement, constantNull(transformedValueJavaType))) + .ifTrue(blockBuilder.invoke("appendNull", BlockBuilder.class).pop()) + .ifFalse(constantType(binder, transformedValueType).writeValue(blockBuilder, transformedValueElement.cast(transformedValueType.getJavaType()))); + } + else { + writeTransformedValueElement = new BytecodeBlock().append(blockBuilder.invoke("appendNull", BlockBuilder.class).pop()); + } + + body.append(new ForLoop() + .initialize(position.set(constantInt(0))) + .condition(lessThan(position, positionCount)) + .update(incrementVariable(position, (byte) 2)) + .body(new BytecodeBlock() + .append(loadKeyElement) + .append(loadValueElement) + .append(transformedValueElement.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(transformedValueJavaType))) + .append(keySqlType.invoke("appendTo", void.class, block, position, blockBuilder)) + .append(writeTransformedValueElement))); + + body.append(mapBlockBuilder + .invoke("closeEntry", BlockBuilder.class) + .pop()); + body.append(constantType(binder, resultMapType) + .invoke( + "getObject", + Object.class, + mapBlockBuilder.cast(Block.class), + subtract(mapBlockBuilder.invoke("getPositionCount", int.class), constantInt(1))) + .ret()); + + Class generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapTransformValueFunction.class.getClassLoader()); + return methodHandle(generatedClass, "transform", Object.class, Block.class, BinaryFunctionInterface.class); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index d7fc2c57baca4..aae3dbe58d0fe 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -75,7 +75,14 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ScalarImplementation implementation = implementations.getExactImplementations().get(boundSignature); Optional methodHandleAndConstructor = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); checkCondition(methodHandleAndConstructor.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, String.format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); - return new ScalarFunctionImplementation(implementation.isNullable(), implementation.getNullableArguments(), implementation.getNullFlags(), methodHandleAndConstructor.get().getMethodHandle(), methodHandleAndConstructor.get().getConstructor(), isDeterministic()); + return new ScalarFunctionImplementation( + implementation.isNullable(), + implementation.getNullableArguments(), + implementation.getNullFlags(), + implementation.getLambdaInterface(), + methodHandleAndConstructor.get().getMethodHandle(), + methodHandleAndConstructor.get().getConstructor(), + isDeterministic()); } ScalarFunctionImplementation selectedImplementation = null; @@ -83,7 +90,14 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); if (methodHandle.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation(implementation.isNullable(), implementation.getNullableArguments(), implementation.getNullFlags(), methodHandle.get().getMethodHandle(), methodHandle.get().getConstructor(), isDeterministic()); + selectedImplementation = new ScalarFunctionImplementation( + implementation.isNullable(), + implementation.getNullableArguments(), + implementation.getNullFlags(), + implementation.getLambdaInterface(), + methodHandle.get().getMethodHandle(), + methodHandle.get().getConstructor(), + isDeterministic()); } } if (selectedImplementation != null) { @@ -94,7 +108,14 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Optional methodHandle = implementation.specialize(boundSignature, boundVariables, typeManager, functionRegistry); if (methodHandle.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", getSignature(), boundVariables.getTypeVariables()); - selectedImplementation = new ScalarFunctionImplementation(implementation.isNullable(), implementation.getNullableArguments(), implementation.getNullFlags(), methodHandle.get().getMethodHandle(), methodHandle.get().getConstructor(), isDeterministic()); + selectedImplementation = new ScalarFunctionImplementation( + implementation.isNullable(), + implementation.getNullableArguments(), + implementation.getNullFlags(), + implementation.getLambdaInterface(), + methodHandle.get().getMethodHandle(), + methodHandle.get().getConstructor(), + isDeterministic()); } } if (selectedImplementation != null) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java new file mode 100644 index 0000000000000..3d8dcd86abc47 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RepeatFunction.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; + +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.facebook.presto.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; + +@ScalarFunction("repeat") +@Description("Repeat an element for a given number of times") +public final class RepeatFunction +{ + private static final long MAX_RESULT_ENTRIES = 10_000; + private static final long MAX_SIZE_IN_BYTES = 1_000_000; + + private RepeatFunction() {} + + @SqlType("array(unknown)") + public static Block repeat( + @SqlNullable @SqlType("unknown") Void element, + @SqlType(StandardTypes.INTEGER) long count) + { + checkCondition(element == null, INVALID_FUNCTION_ARGUMENT, "expect null values"); + BlockBuilder blockBuilder = createBlockBuilder(UNKNOWN, count); + return repeatNullValues(blockBuilder, count); + } + + @TypeParameter("T") + @SqlType("array(T)") + public static Block repeat( + @TypeParameter("T") Type type, + @SqlNullable @SqlType("T") Object element, + @SqlType(StandardTypes.INTEGER) long count) + { + BlockBuilder blockBuilder = createBlockBuilder(type, count); + if (element == null) { + return repeatNullValues(blockBuilder, count); + } + if (count > 0) { + type.writeObject(blockBuilder, element); + checkMaxSize(blockBuilder.getSizeInBytes(), count); + } + for (int i = 1; i < count; i++) { + type.writeObject(blockBuilder, element); + } + return blockBuilder.build(); + } + + @TypeParameter("T") + @SqlType("array(T)") + public static Block repeat( + @TypeParameter("T") Type type, + @SqlNullable @SqlType("T") Long element, + @SqlType(StandardTypes.INTEGER) long count) + { + BlockBuilder blockBuilder = createBlockBuilder(type, count); + if (element == null) { + return repeatNullValues(blockBuilder, count); + } + for (int i = 0; i < count; i++) { + type.writeLong(blockBuilder, element); + } + return blockBuilder.build(); + } + + @TypeParameter("T") + @SqlType("array(T)") + public static Block repeat( + @TypeParameter("T") Type type, + @SqlNullable @SqlType("T") Slice element, + @SqlType(StandardTypes.INTEGER) long count) + { + BlockBuilder blockBuilder = createBlockBuilder(type, count); + if (element == null) { + return repeatNullValues(blockBuilder, count); + } + if (count > 0) { + type.writeSlice(blockBuilder, element); + checkMaxSize(blockBuilder.getSizeInBytes(), count); + } + for (int i = 1; i < count; i++) { + type.writeSlice(blockBuilder, element); + } + return blockBuilder.build(); + } + + @TypeParameter("T") + @SqlType("array(T)") + public static Block repeat( + @TypeParameter("T") Type type, + @SqlNullable @SqlType("T") Boolean element, + @SqlType(StandardTypes.INTEGER) long count) + { + BlockBuilder blockBuilder = createBlockBuilder(type, count); + if (element == null) { + return repeatNullValues(blockBuilder, count); + } + for (int i = 0; i < count; i++) { + type.writeBoolean(blockBuilder, element); + } + return blockBuilder.build(); + } + + @TypeParameter("T") + @SqlType("array(T)") + public static Block repeat( + @TypeParameter("T") Type type, + @SqlNullable @SqlType("T") Double element, + @SqlType(StandardTypes.INTEGER) long count) + { + BlockBuilder blockBuilder = createBlockBuilder(type, count); + if (element == null) { + return repeatNullValues(blockBuilder, count); + } + for (int i = 0; i < count; i++) { + type.writeDouble(blockBuilder, element); + } + return blockBuilder.build(); + } + + private static BlockBuilder createBlockBuilder(Type type, long count) + { + checkCondition(count <= MAX_RESULT_ENTRIES, INVALID_FUNCTION_ARGUMENT, "count argument of repeat function must be less than or equal to 10000"); + checkCondition(count >= 0, INVALID_FUNCTION_ARGUMENT, "count argument of repeat function must be greater than or equal to 0"); + return type.createBlockBuilder(new BlockBuilderStatus(), toIntExact(count)); + } + + private static Block repeatNullValues(BlockBuilder blockBuilder, long count) + { + for (int i = 0; i < count; i++) { + blockBuilder.appendNull(); + } + return blockBuilder.build(); + } + + private static void checkMaxSize(long bytes, long count) + { + checkCondition( + bytes <= (MAX_SIZE_IN_BYTES + count) / count, + INVALID_FUNCTION_ARGUMENT, + "result of repeat function must not take more than 1000000 bytes"); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java index 44dac4602ed83..9835d4bae40d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java @@ -19,9 +19,9 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.RowType; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java index b2b3dc3c67f30..c13be17367080 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java index b7fc55e79b6a7..8aadf24d884cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowGreaterThanOrEqualOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java index 9fa72bdf2d8b1..013df743bf74d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java index 82c7519499b72..4f71aaf7c70ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowLessThanOrEqualOperator.java @@ -16,9 +16,9 @@ import com.facebook.presto.metadata.BoundVariables; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java index f4c8c6b0ec207..f938fa0b1b14d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarFunctionImplementation.java @@ -16,11 +16,13 @@ import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; -import java.util.Collections; import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public final class ScalarFunctionImplementation @@ -28,25 +30,66 @@ public final class ScalarFunctionImplementation private final boolean nullable; private final List nullableArguments; private final List nullFlags; + private final List> lambdaInterface; private final MethodHandle methodHandle; private final Optional instanceFactory; private final boolean deterministic; public ScalarFunctionImplementation(boolean nullable, List nullableArguments, MethodHandle methodHandle, boolean deterministic) { - this(nullable, nullableArguments, Collections.nCopies(nullableArguments.size(), false), methodHandle, Optional.empty(), deterministic); + this( + nullable, + nullableArguments, + nCopies(nullableArguments.size(), false), + nCopies(nullableArguments.size(), Optional.empty()), + methodHandle, + Optional.empty(), + deterministic); } public ScalarFunctionImplementation(boolean nullable, List nullableArguments, List nullFlags, MethodHandle methodHandle, boolean deterministic) { - this(nullable, nullableArguments, nullFlags, methodHandle, Optional.empty(), deterministic); + this( + nullable, + nullableArguments, + nullFlags, + nCopies(nullableArguments.size(), Optional.empty()), + methodHandle, + Optional.empty(), + deterministic); } - public ScalarFunctionImplementation(boolean nullable, List nullableArguments, List nullFlags, MethodHandle methodHandle, Optional instanceFactory, boolean deterministic) + public ScalarFunctionImplementation( + boolean nullable, + List nullableArguments, + List nullFlags, + List> lambdaInterface, + MethodHandle methodHandle, + boolean deterministic) + { + this( + nullable, + nullableArguments, + nullFlags, + lambdaInterface, + methodHandle, + Optional.empty(), + deterministic); + } + + public ScalarFunctionImplementation( + boolean nullable, + List nullableArguments, + List nullFlags, + List> lambdaInterface, + MethodHandle methodHandle, + Optional instanceFactory, + boolean deterministic) { this.nullable = nullable; this.nullableArguments = ImmutableList.copyOf(requireNonNull(nullableArguments, "nullableArguments is null")); this.nullFlags = ImmutableList.copyOf(requireNonNull(nullFlags, "nullFlags is null")); + this.lambdaInterface = ImmutableList.copyOf(requireNonNull(lambdaInterface, "lambdaInterface is null")); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); this.deterministic = deterministic; @@ -56,10 +99,19 @@ public ScalarFunctionImplementation(boolean nullable, List nullableArgu checkArgument(instanceType.equals(methodHandle.type().parameterType(0)), "methodHandle is not an instance method"); } - // check if nullableArguments and nullFlags match + checkCondition(nullFlags.size() == nullableArguments.size(), FUNCTION_IMPLEMENTATION_ERROR, "size of nullFlags is not equal to size of nullableArguments: %s", methodHandle); + checkCondition(nullFlags.size() == lambdaInterface.size(), FUNCTION_IMPLEMENTATION_ERROR, "size of nullFlags is not equal to size of lambdaInterface: %s", methodHandle); + // check if + // - nullableArguments and nullFlags match + // - lambda interface is not nullable + // - lambda interface is annotated with FunctionalInterface for (int i = 0; i < nullFlags.size(); i++) { if (nullFlags.get(i)) { - checkArgument((boolean) nullableArguments.get(i), "argument %s marked as @IsNull is not nullable in method: %s", i, methodHandle); + checkCondition(nullableArguments.get(i), FUNCTION_IMPLEMENTATION_ERROR, "argument %s marked as @IsNull is not nullable in method: %s", i, methodHandle); + } + if (lambdaInterface.get(i).isPresent()) { + checkCondition(!nullableArguments.get(i), FUNCTION_IMPLEMENTATION_ERROR, "argument %s marked as lambda is nullable in method: %s", i, methodHandle); + checkCondition(lambdaInterface.get(i).get().isAnnotationPresent(FunctionalInterface.class), FUNCTION_IMPLEMENTATION_ERROR, "argument %s is marked as lambda but the function interface class is not annotated: %s", i, methodHandle); } } } @@ -79,6 +131,11 @@ public List getNullFlags() return nullFlags; } + public List> getLambdaInterface() + { + return lambdaInterface; + } + public MethodHandle getMethodHandle() { return methodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/SequenceFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/SequenceFunction.java index 29dd4b3c5ab0a..dbfb23d710a49 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/SequenceFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/SequenceFunction.java @@ -35,6 +35,7 @@ public final class SequenceFunction { + private static final long MAX_RESULT_ENTRIES = 10_000; private static final Slice MONTH = Slices.utf8Slice("month"); private SequenceFunction() {} @@ -82,6 +83,7 @@ public static Block sequenceTimestampYearToMonth( "sequence end value should be greater than or equal to start value if step is greater than zero otherwise end should be less than start"); int length = toIntExact(diffTimestamp(session, MONTH, start, end) / step + 1); + checkCondition(length <= MAX_RESULT_ENTRIES, INVALID_FUNCTION_ARGUMENT, "result of sequence function must not have more than 10000 entries"); BlockBuilder blockBuilder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), length); @@ -101,6 +103,7 @@ private static Block fixedWidthSequence(long start, long stop, long step, FixedW "sequence stop value should be greater than or equal to start value if step is greater than zero otherwise stop should be less than start"); int length = toIntExact((stop - start) / step + 1L); + checkCondition(length <= MAX_RESULT_ENTRIES, INVALID_FUNCTION_ARGUMENT, "result of sequence function must not have more than 10000 entries"); BlockBuilder blockBuilder = type.createBlockBuilder(new BlockBuilderStatus(), length); for (long i = 0, value = start; i < length; ++i, value += step) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java index 3bf9484eb01d8..ea9d919a9f652 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java @@ -26,6 +26,7 @@ import io.airlift.slice.XxHash64; import java.util.Base64; +import java.util.zip.CRC32; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -157,6 +158,26 @@ public static long fromBigEndian64(@SqlType(StandardTypes.VARBINARY) Slice slice return Long.reverseBytes(slice.getLong(0)); } + @Description("encode value as a big endian varbinary according to IEEE 754 single-precision floating-point format") + @ScalarFunction("to_ieee754_32") + @SqlType(StandardTypes.VARBINARY) + public static Slice toIEEE754Binary32(@SqlType(StandardTypes.REAL) long value) + { + Slice slice = Slices.allocate(Float.BYTES); + slice.setInt(0, Integer.reverseBytes((int) value)); + return slice; + } + + @Description("encode value as a big endian varbinary according to IEEE 754 double-precision floating-point format") + @ScalarFunction("to_ieee754_64") + @SqlType(StandardTypes.VARBINARY) + public static Slice toIEEE754Binary64(@SqlType(StandardTypes.DOUBLE) double value) + { + Slice slice = Slices.allocate(Double.BYTES); + slice.setLong(0, Long.reverseBytes(Double.doubleToLongBits(value))); + return slice; + } + @Description("compute md5 hash") @ScalarFunction @SqlType(StandardTypes.VARBINARY) @@ -220,4 +241,14 @@ public static Slice fromHexVarbinary(@SqlType(StandardTypes.VARBINARY) Slice sli { return fromHexVarchar(slice); } + + @Description("compute CRC-32") + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + public static long crc32(@SqlType(StandardTypes.VARBINARY) Slice slice) + { + CRC32 crc32 = new CRC32(); + crc32.update(slice.toByteBuffer()); + return crc32.getValue(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java index 5138cd93df457..36ea2d1f2c9c2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipFunction.java @@ -22,10 +22,10 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java index ecfa421315a9b..5385a222088f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ZipWithFunction.java @@ -23,10 +23,12 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import java.lang.invoke.MethodHandle; +import java.util.Optional; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -41,7 +43,7 @@ public final class ZipWithFunction { public static final ZipWithFunction ZIP_WITH_FUNCTION = new ZipWithFunction(); - private static final MethodHandle METHOD_HANDLE = methodHandle(ZipWithFunction.class, "zipWith", Type.class, Type.class, Type.class, Block.class, Block.class, MethodHandle.class); + private static final MethodHandle METHOD_HANDLE = methodHandle(ZipWithFunction.class, "zipWith", Type.class, Type.class, Type.class, Block.class, Block.class, BinaryFunctionInterface.class); private ZipWithFunction() { @@ -82,11 +84,13 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in return new ScalarFunctionImplementation( false, ImmutableList.of(false, false, false), + ImmutableList.of(false, false, false), + ImmutableList.of(Optional.empty(), Optional.empty(), Optional.of(BinaryFunctionInterface.class)), METHOD_HANDLE.bindTo(leftElementType).bindTo(rightElementType).bindTo(outputElementType), isDeterministic()); } - public static Block zipWith(Type leftElementType, Type rightElementType, Type outputElementType, Block leftBlock, Block rightBlock, MethodHandle function) + public static Block zipWith(Type leftElementType, Type rightElementType, Type outputElementType, Block leftBlock, Block rightBlock, BinaryFunctionInterface function) { checkCondition(leftBlock.getPositionCount() == rightBlock.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "Arrays must have the same length"); BlockBuilder resultBuilder = outputElementType.createBlockBuilder(new BlockBuilderStatus(), leftBlock.getPositionCount()); @@ -95,7 +99,7 @@ public static Block zipWith(Type leftElementType, Type rightElementType, Type ou Object right = readNativeValue(rightElementType, rightBlock, position); Object output; try { - output = function.invoke(left, right); + output = function.apply(left, right); } catch (Throwable throwable) { throw Throwables.propagate(throwable); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java index 28f376fac3167..8c94d24877bbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementation.java @@ -18,8 +18,8 @@ import com.facebook.presto.metadata.LongVariableConstraint; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TypeVariableConstraint; +import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.FunctionDependency; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; @@ -32,7 +32,9 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.type.Constraint; +import com.facebook.presto.type.FunctionType; import com.facebook.presto.type.LiteralParameter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -77,11 +79,13 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.util.Failures.checkCondition; +import static com.facebook.presto.util.Reflection.constructorMethodHandle; +import static com.facebook.presto.util.Reflection.methodHandle; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.String.format; -import static java.lang.invoke.MethodHandles.lookup; import static java.lang.invoke.MethodHandles.permuteArguments; import static java.lang.reflect.Modifier.isStatic; import static java.util.Arrays.asList; @@ -93,6 +97,7 @@ public class ScalarImplementation private final boolean nullable; private final List nullableArguments; private final List nullFlags; + private final List> lambdaInterface; private final MethodHandle methodHandle; private final List dependencies; private final Optional constructor; @@ -105,6 +110,7 @@ public ScalarImplementation( boolean nullable, List nullableArguments, List nullFlags, + List> lambdaInterface, MethodHandle methodHandle, List dependencies, Optional constructor, @@ -116,6 +122,7 @@ public ScalarImplementation( this.nullable = nullable; this.nullableArguments = ImmutableList.copyOf(requireNonNull(nullableArguments, "nullableArguments is null")); this.nullFlags = ImmutableList.copyOf(requireNonNull(nullFlags, "nullFlags is null")); + this.lambdaInterface = ImmutableList.copyOf(requireNonNull(lambdaInterface, "lambdaInterface is null")); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null")); this.constructor = requireNonNull(constructor, "constructor is null"); @@ -136,11 +143,20 @@ public Optional specialize(Signature boundSignature, return Optional.empty(); } for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) { - Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); - boolean nullableParameter = isParameterNullable(argumentType, nullableArguments.get(i), nullFlags.get(i)); - Class argumentContainerType = getNullAwareContainerType(argumentType, nullableParameter); - if (!argumentNativeContainerTypes.get(i).isAssignableFrom(argumentContainerType)) { - return Optional.empty(); + if (boundSignature.getArgumentTypes().get(i).getBase().equals(FunctionType.NAME)) { + // function does not have a corresponding Java type, an instance of specified interface + // with single abstract method will be generated. + if (!lambdaInterface.get(i).isPresent()) { + return Optional.empty(); + } + } + else { + Class argumentType = typeManager.getType(boundSignature.getArgumentTypes().get(i)).getJavaType(); + boolean nullableParameter = isParameterNullable(argumentType, nullableArguments.get(i), nullFlags.get(i)); + Class argumentContainerType = getNullAwareContainerType(argumentType, nullableParameter); + if (!argumentNativeContainerTypes.get(i).isAssignableFrom(argumentContainerType)) { + return Optional.empty(); + } } } MethodHandle methodHandle = this.methodHandle; @@ -216,6 +232,11 @@ public List getNullFlags() return nullFlags; } + public List> getLambdaInterface() + { + return lambdaInterface; + } + public MethodHandle getMethodHandle() { return methodHandle; @@ -298,7 +319,14 @@ public Signature getSignature() public MethodHandle resolve(BoundVariables boundVariables, TypeManager typeManager, FunctionRegistry functionRegistry) { Signature signature = applyBoundVariables(this.signature, boundVariables, this.signature.getArgumentTypes().size()); - return functionRegistry.getScalarFunctionImplementation(signature).getMethodHandle(); + ScalarFunctionImplementation scalarFunctionImplementation = functionRegistry.getScalarFunctionImplementation(signature); + if (scalarFunctionImplementation.getInstanceFactory().isPresent()) { + // TODO: This feature is useful for a few casts, e.g. MapToMapCast, JsonToMapCast + // Implementing this requires a revamp because we must be able to defer binding of MethodHandles, + // and be able to express such need in a recursive way in ScalarFunctionImplementation. + throw new UnsupportedOperationException("OperatorDependency/FunctionDependency cannot refer to methods with instance factory"); + } + return scalarFunctionImplementation.getMethodHandle(); } } @@ -345,12 +373,14 @@ public static final class Parser private final boolean nullable; private final List nullableArguments = new ArrayList<>(); private final List nullFlags = new ArrayList<>(); + private final List> lambdaInterface = new ArrayList<>(); private final TypeSignature returnType; private final List argumentTypes = new ArrayList<>(); private final List> argumentNativeContainerTypes = new ArrayList<>(); private final MethodHandle methodHandle; private final List dependencies = new ArrayList<>(); private final LinkedHashSet typeParameters = new LinkedHashSet<>(); + private final ImmutableSet typeParameterNames; private final Set literalParameters = new HashSet<>(); private final Map> specializedTypeParameters; private final Optional constructorMethodHandle; @@ -366,6 +396,10 @@ private Parser(String functionName, Method method, Map, Const Stream.of(method.getAnnotationsByType(TypeParameter.class)) .forEach(typeParameters::add); + typeParameterNames = typeParameters.stream() + .map(TypeParameter::value) + .collect(toImmutableSet()); + LiteralParameters literalParametersAnnotation = method.getAnnotation(LiteralParameters.class); if (literalParametersAnnotation != null) { literalParameters.addAll(asList(literalParametersAnnotation.value())); @@ -389,6 +423,12 @@ else if (actualReturnType.isPrimitive()) { this.specializedTypeParameters = getDeclaredSpecializedTypeParameters(method); + for (TypeParameter typeParameter : typeParameters) { + checkArgument( + typeParameter.value().matches("[A-Z][A-Z0-9]*"), + "Expected type parameter to only contain A-Z and 0-9 (starting with A-Z), but got %s on method [%s]", typeParameter.value(), method); + } + parseArguments(method); this.constructorMethodHandle = getConstructor(method, constructors); @@ -398,9 +438,6 @@ else if (actualReturnType.isPrimitive()) { private void parseArguments(Method method) { - ImmutableSet typeParameterNames = typeParameters.stream() - .map(TypeParameter::value) - .collect(toImmutableSet()); for (int i = 0; i < method.getParameterCount(); i++) { Annotation[] annotations = method.getParameterAnnotations()[i]; Class parameterType = method.getParameterTypes()[i]; @@ -413,7 +450,7 @@ private void parseArguments(Method method) checkArgument(argumentTypes.isEmpty(), "Meta parameter must come before parameters [%s]", method); Annotation annotation = annotations[0]; if (annotation instanceof TypeParameter) { - checkArgument(typeParameters.contains(annotation), "Injected type parameters must be declared with @TypeParameter annotation on the method [%s]", method); + checkTypeParameters(parseTypeSignature(((TypeParameter) annotation).value()), method, typeParameterNames); } if (annotation instanceof LiteralParameter) { checkArgument(literalParameters.contains(((LiteralParameter) annotation).value()), "Parameter injected by @LiteralParameter must be declared with @LiteralParameters on the method [%s]", method); @@ -428,6 +465,7 @@ private void parseArguments(Method method) .map(SqlType.class::cast) .findFirst() .orElseThrow(() -> new IllegalArgumentException(format("Method [%s] is missing @SqlType annotation for parameter", method))); + TypeSignature typeSignature = parseTypeSignature(type.value(), literalParameters); boolean nullableArgument = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance); checkArgument(nullableArgument || !containsLegacyNullable(annotations), "Method [%s] has parameter annotated with @Nullable but not @SqlNullable", method); @@ -461,14 +499,38 @@ else if (parameterType.isPrimitive() && !hasNullFlag) { specializedTypeParameters.put(type.value(), nativeParameterType); } argumentNativeContainerTypes.add(parameterType); - argumentTypes.add(parseTypeSignature(type.value(), literalParameters)); + argumentTypes.add(typeSignature); if (hasNullFlag) { // skip @IsNull parameter i++; } + nullableArguments.add(nullableArgument); nullFlags.add(hasNullFlag); + if (typeSignature.getBase().equals(FunctionType.NAME)) { + checkCondition(parameterType.isAnnotationPresent(FunctionalInterface.class), FUNCTION_IMPLEMENTATION_ERROR, "argument %s is marked as lambda but the function interface class is not annotated: %s", i, methodHandle); + lambdaInterface.add(Optional.of(parameterType)); + } + else { + lambdaInterface.add(Optional.empty()); + } + } + } + } + + private void checkTypeParameters(TypeSignature typeSignature, Method method, Set typeParameterNames) + { + // Check recursively if `typeSignature` contains something like `T` + if (typeParameterNames.contains(typeSignature.getBase())) { + checkArgument(typeSignature.getParameters().isEmpty(), "Expected type parameter not to take parameters, but got %s on method [%s]", typeSignature.getBase(), method); + return; + } + + for (TypeSignatureParameter parameter : typeSignature.getParameters()) { + Optional childTypeSignature = parameter.getTypeSignatureOrNamedTypeSignature(); + if (childTypeSignature.isPresent()) { + checkTypeParameters(childTypeSignature.get(), method, typeParameterNames); } } } @@ -488,16 +550,11 @@ private Optional getConstructor(Method method, Map> getDeclaredSpecializedTypeParameters(Method method) @@ -519,13 +576,7 @@ private Map> getDeclaredSpecializedTypeParameters(Method method private MethodHandle getMethodHandle(Method method) { - MethodHandle methodHandle; - try { - methodHandle = lookup().unreflect(method); - } - catch (IllegalAccessException e) { - throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, e); - } + MethodHandle methodHandle = methodHandle(FUNCTION_IMPLEMENTATION_ERROR, method); if (!isStatic(method.getModifiers())) { // Re-arrange the parameters, so that the "this" parameter is after the meta parameters int[] permutedIndices = new int[methodHandle.type().parameterCount()]; @@ -644,6 +695,7 @@ public ScalarImplementation get() nullable, nullableArguments, nullFlags, + lambdaInterface, methodHandle, dependencies, constructorMethodHandle, diff --git a/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java index d1382eb7c6b2f..42768dae4413b 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AccessControl.java @@ -36,6 +36,11 @@ public interface AccessControl */ Set filterCatalogs(Identity identity, Set catalogs); + /** + * Check whether identity is allowed to access catalog + */ + void checkCanAccessCatalog(Identity identity, String catalogName); + /** * Check if identity is allowed to create the specified schema. * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed @@ -165,16 +170,16 @@ public interface AccessControl void checkCanCreateViewWithSelectFromView(TransactionId transactionId, Identity identity, QualifiedObjectName viewName); /** - * Check if identity is allowed to grant a privilege on the specified table. + * Check if identity is allowed to grant a privilege to the grantee on the specified table. * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ - void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName); + void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String grantee, boolean withGrantOption); /** - * Check if identity is allowed to revoke a privilege on the specified table. + * Check if identity is allowed to revoke a privilege from the revokee on the specified table. * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ - void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName); + void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String revokee, boolean grantOptionFor); /** * Check if identity is allowed to set the specified system property. diff --git a/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java b/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java index 93e76d67a734c..e45681386c6c8 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AccessControlManager.java @@ -28,6 +28,7 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.log.Logger; import io.airlift.stats.CounterStat; import org.weakref.jmx.Managed; @@ -79,6 +80,7 @@ public AccessControlManager(TransactionManager transactionManager) this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); addSystemAccessControlFactory(new AllowAllSystemAccessControl.Factory()); addSystemAccessControlFactory(new ReadOnlySystemAccessControl.Factory()); + addSystemAccessControlFactory(new FileBasedSystemAccessControl.Factory()); } public void addSystemAccessControlFactory(SystemAccessControlFactory accessControlFactory) @@ -156,12 +158,23 @@ public Set filterCatalogs(Identity identity, Set catalogs) return systemAccessControl.get().filterCatalogs(identity, catalogs); } + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(catalogName, "catalog is null"); + + authenticationCheck(() -> systemAccessControl.get().checkCanAccessCatalog(identity, catalogName)); + } + @Override public void checkCanCreateSchema(TransactionId transactionId, Identity identity, CatalogSchemaName schemaName) { requireNonNull(identity, "identity is null"); requireNonNull(schemaName, "schemaName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, schemaName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanCreateSchema(identity, schemaName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schemaName.getCatalogName()); @@ -176,6 +189,8 @@ public void checkCanDropSchema(TransactionId transactionId, Identity identity, C requireNonNull(identity, "identity is null"); requireNonNull(schemaName, "schemaName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, schemaName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanDropSchema(identity, schemaName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schemaName.getCatalogName()); @@ -190,6 +205,8 @@ public void checkCanRenameSchema(TransactionId transactionId, Identity identity, requireNonNull(identity, "identity is null"); requireNonNull(schemaName, "schemaName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, schemaName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanRenameSchema(identity, schemaName, newSchemaName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schemaName.getCatalogName()); @@ -204,6 +221,8 @@ public void checkCanShowSchemas(TransactionId transactionId, Identity identity, requireNonNull(identity, "identity is null"); requireNonNull(catalogName, "catalogName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, catalogName)); + authorizationCheck(() -> systemAccessControl.get().checkCanShowSchemas(identity, catalogName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); @@ -219,6 +238,10 @@ public Set filterSchemas(TransactionId transactionId, Identity identity, requireNonNull(catalogName, "catalogName is null"); requireNonNull(schemaNames, "schemaNames is null"); + if (filterCatalogs(identity, ImmutableSet.of(catalogName)).isEmpty()) { + return ImmutableSet.of(); + } + schemaNames = systemAccessControl.get().filterSchemas(identity, catalogName, schemaNames); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); @@ -234,6 +257,8 @@ public void checkCanCreateTable(TransactionId transactionId, Identity identity, requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanCreateTable(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -248,6 +273,8 @@ public void checkCanDropTable(TransactionId transactionId, Identity identity, Qu requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanDropTable(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -263,6 +290,8 @@ public void checkCanRenameTable(TransactionId transactionId, Identity identity, requireNonNull(tableName, "tableName is null"); requireNonNull(newTableName, "newTableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanRenameTable(identity, tableName.asCatalogSchemaTableName(), newTableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -277,6 +306,8 @@ public void checkCanShowTablesMetadata(TransactionId transactionId, Identity ide requireNonNull(identity, "identity is null"); requireNonNull(schema, "schema is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, schema.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanShowTablesMetadata(identity, schema)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schema.getCatalogName()); @@ -292,6 +323,10 @@ public Set filterTables(TransactionId transactionId, Identity i requireNonNull(catalogName, "catalogName is null"); requireNonNull(tableNames, "tableNames is null"); + if (filterCatalogs(identity, ImmutableSet.of(catalogName)).isEmpty()) { + return ImmutableSet.of(); + } + tableNames = systemAccessControl.get().filterTables(identity, catalogName, tableNames); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); @@ -307,6 +342,8 @@ public void checkCanAddColumns(TransactionId transactionId, Identity identity, Q requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanAddColumn(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -321,6 +358,8 @@ public void checkCanRenameColumn(TransactionId transactionId, Identity identity, requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanRenameColumn(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -335,6 +374,8 @@ public void checkCanSelectFromTable(TransactionId transactionId, Identity identi requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanSelectFromTable(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -349,6 +390,8 @@ public void checkCanInsertIntoTable(TransactionId transactionId, Identity identi requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanInsertIntoTable(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -363,6 +406,8 @@ public void checkCanDeleteFromTable(TransactionId transactionId, Identity identi requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanDeleteFromTable(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -377,6 +422,8 @@ public void checkCanCreateView(TransactionId transactionId, Identity identity, Q requireNonNull(identity, "identity is null"); requireNonNull(viewName, "viewName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, viewName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanCreateView(identity, viewName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); @@ -391,6 +438,8 @@ public void checkCanDropView(TransactionId transactionId, Identity identity, Qua requireNonNull(identity, "identity is null"); requireNonNull(viewName, "viewName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, viewName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanDropView(identity, viewName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); @@ -405,6 +454,8 @@ public void checkCanSelectFromView(TransactionId transactionId, Identity identit requireNonNull(identity, "identity is null"); requireNonNull(viewName, "viewName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, viewName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanSelectFromView(identity, viewName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); @@ -419,6 +470,8 @@ public void checkCanCreateViewWithSelectFromTable(TransactionId transactionId, I requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanCreateViewWithSelectFromTable(identity, tableName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); @@ -433,6 +486,8 @@ public void checkCanCreateViewWithSelectFromView(TransactionId transactionId, Id requireNonNull(identity, "identity is null"); requireNonNull(viewName, "viewName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, viewName.getCatalogName())); + authorizationCheck(() -> systemAccessControl.get().checkCanCreateViewWithSelectFromView(identity, viewName.asCatalogSchemaTableName())); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); @@ -442,32 +497,36 @@ public void checkCanCreateViewWithSelectFromView(TransactionId transactionId, Id } @Override - public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName) + public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String grantee, boolean withGrantOption) { requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); requireNonNull(privilege, "privilege is null"); - authorizationCheck(() -> systemAccessControl.get().checkCanGrantTablePrivilege(identity, privilege, tableName.asCatalogSchemaTableName())); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.get().checkCanGrantTablePrivilege(identity, privilege, tableName.asCatalogSchemaTableName(), grantee, withGrantOption)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { - authorizationCheck(() -> entry.getAccessControl().checkCanGrantTablePrivilege(entry.getTransactionHandle(transactionId), identity, privilege, tableName.asSchemaTableName())); + authorizationCheck(() -> entry.getAccessControl().checkCanGrantTablePrivilege(entry.getTransactionHandle(transactionId), identity, privilege, tableName.asSchemaTableName(), grantee, withGrantOption)); } } @Override - public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName) + public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String revokee, boolean grantOptionFor) { requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); requireNonNull(privilege, "privilege is null"); - authorizationCheck(() -> systemAccessControl.get().checkCanRevokeTablePrivilege(identity, privilege, tableName.asCatalogSchemaTableName())); + authenticationCheck(() -> checkCanAccessCatalog(identity, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.get().checkCanRevokeTablePrivilege(identity, privilege, tableName.asCatalogSchemaTableName(), revokee, grantOptionFor)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { - authorizationCheck(() -> entry.getAccessControl().checkCanRevokeTablePrivilege(entry.getTransactionHandle(transactionId), identity, privilege, tableName.asSchemaTableName())); + authorizationCheck(() -> entry.getAccessControl().checkCanRevokeTablePrivilege(entry.getTransactionHandle(transactionId), identity, privilege, tableName.asSchemaTableName(), revokee, grantOptionFor)); } } @@ -487,6 +546,8 @@ public void checkCanSetCatalogSessionProperty(TransactionId transactionId, Ident requireNonNull(catalogName, "catalogName is null"); requireNonNull(propertyName, "propertyName is null"); + authenticationCheck(() -> checkCanAccessCatalog(identity, catalogName)); + authorizationCheck(() -> systemAccessControl.get().checkCanSetCatalogSessionProperty(identity, catalogName, propertyName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); @@ -607,5 +668,11 @@ public void checkCanSetSystemSessionProperty(Identity identity, String propertyN { throw new PrestoException(SERVER_STARTING_UP, "Presto server is still initializing"); } + + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + throw new PrestoException(SERVER_STARTING_UP, "Presto server is still initializing"); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java index 7b65bb1e2c7a8..356ff83c75b88 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AllowAllAccessControl.java @@ -37,6 +37,11 @@ public Set filterCatalogs(Identity identity, Set catalogs) return catalogs; } + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + } + @Override public void checkCanCreateSchema(TransactionId transactionId, Identity identity, CatalogSchemaName schemaName) { @@ -140,12 +145,12 @@ public void checkCanCreateViewWithSelectFromView(TransactionId transactionId, Id } @Override - public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName) + public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String grantee, boolean withGrantOption) { } @Override - public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName) + public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String revokee, boolean grantOptionFor) { } diff --git a/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java index 6a48fbfa53790..442718d4e337f 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java @@ -63,6 +63,11 @@ public void checkCanSetSystemSessionProperty(Identity identity, String propertyN { } + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + } + @Override public Set filterCatalogs(Identity identity, Set catalogs) { @@ -177,12 +182,12 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String catalogN } @Override - public void checkCanGrantTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table) + public void checkCanGrantTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table, String grantee, boolean withGrantOption) { } @Override - public void checkCanRevokeTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table) + public void checkCanRevokeTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table, String revokee, boolean grantOptionFor) { } } diff --git a/presto-main/src/main/java/com/facebook/presto/security/CatalogAccessControlRule.java b/presto-main/src/main/java/com/facebook/presto/security/CatalogAccessControlRule.java new file mode 100644 index 0000000000000..960bbbcdabc5a --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/security/CatalogAccessControlRule.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Optional; +import java.util.regex.Pattern; + +import static java.util.Objects.requireNonNull; + +public class CatalogAccessControlRule +{ + private final boolean allow; + private final Optional userRegex; + private final Optional catalogRegex; + + @JsonCreator + public CatalogAccessControlRule( + @JsonProperty("allow") boolean allow, + @JsonProperty("user") Optional userRegex, + @JsonProperty("catalog") Optional catalogRegex) + { + this.allow = allow; + this.userRegex = requireNonNull(userRegex, "userRegex is null"); + this.catalogRegex = requireNonNull(catalogRegex, "catalogRegex is null"); + } + + public Optional match(String user, String catalog) + { + if (userRegex.map(regex -> regex.matcher(user).matches()).orElse(true) && + catalogRegex.map(regex -> regex.matcher(catalog).matches()).orElse(true)) { + return Optional.of(allow); + } + return Optional.empty(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java index 0cb2959d2465b..1195c1d94d2a2 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/DenyAllAccessControl.java @@ -25,6 +25,7 @@ import java.util.Set; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCatalogAccess; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; @@ -62,6 +63,12 @@ public Set filterCatalogs(Identity identity, Set catalogs) return ImmutableSet.of(); } + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + denyCatalogAccess(catalogName); + } + @Override public void checkCanCreateSchema(TransactionId transactionId, Identity identity, CatalogSchemaName schemaName) { @@ -183,13 +190,13 @@ public void checkCanCreateViewWithSelectFromView(TransactionId transactionId, Id } @Override - public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName) + public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String grantee, boolean withGrantOption) { denyGrantTablePrivilege(privilege.name(), tableName.toString()); } @Override - public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName) + public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity identity, Privilege privilege, QualifiedObjectName tableName, String revokee, boolean grantOptionFor) { denyRevokeTablePrivilege(privilege.name(), tableName.toString()); } diff --git a/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java new file mode 100644 index 0000000000000..1490f6ec74cc4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java @@ -0,0 +1,268 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.facebook.presto.spi.CatalogSchemaName; +import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.SystemAccessControl; +import com.facebook.presto.spi.security.SystemAccessControlFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.InvalidPathException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.Principal; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Pattern; + +import static com.facebook.presto.spi.security.AccessDeniedException.denyCatalogAccess; +import static com.google.common.base.Preconditions.checkState; +import static io.airlift.json.JsonCodec.jsonCodec; +import static java.util.Objects.requireNonNull; + +public class FileBasedSystemAccessControl + implements SystemAccessControl +{ + public static final String NAME = "file"; + + private final List catalogRules; + + private FileBasedSystemAccessControl(List catalogRules) + { + this.catalogRules = catalogRules; + } + + public static class Factory + implements SystemAccessControlFactory + { + private static final String CONFIG_FILE_NAME = "security.config-file"; + + @Override + public String getName() + { + return NAME; + } + + @Override + public SystemAccessControl create(Map config) + { + requireNonNull(config, "config is null"); + + String configFileName = config.get(CONFIG_FILE_NAME); + checkState( + configFileName != null, + "Security configuration must contain the '%s' property", CONFIG_FILE_NAME); + + try { + Path path = Paths.get(configFileName); + if (!path.isAbsolute()) { + path = path.toAbsolutePath(); + } + path.toFile().canRead(); + + ImmutableList.Builder catalogRulesBuilder = ImmutableList.builder(); + catalogRulesBuilder.addAll(jsonCodec(FileBasedSystemAccessControlRules.class) + .fromJson(Files.readAllBytes(path)) + .getCatalogRules()); + + // Hack to allow Presto Admin to access the "system" catalog for retrieving server status. + // todo Change userRegex from ".*" to one particular user that Presto Admin will be restricted to run as + catalogRulesBuilder.add(new CatalogAccessControlRule( + true, + Optional.of(Pattern.compile(".*")), + Optional.of(Pattern.compile("system")))); + + return new FileBasedSystemAccessControl(catalogRulesBuilder.build()); + } + catch (SecurityException | IOException | InvalidPathException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public void checkCanSetUser(Principal principal, String userName) + { + } + + @Override + public void checkCanSetSystemSessionProperty(Identity identity, String propertyName) + { + } + + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + if (!canAccessCatalog(identity, catalogName)) { + denyCatalogAccess(catalogName); + } + } + + @Override + public Set filterCatalogs(Identity identity, Set catalogs) + { + ImmutableSet.Builder filteredCatalogs = ImmutableSet.builder(); + for (String catalog : catalogs) { + if (canAccessCatalog(identity, catalog)) { + filteredCatalogs.add(catalog); + } + } + return filteredCatalogs.build(); + } + + private boolean canAccessCatalog(Identity identity, String catalogName) + { + for (CatalogAccessControlRule rule : catalogRules) { + Optional allowed = rule.match(identity.getUser(), catalogName); + if (allowed.isPresent()) { + return allowed.get(); + } + } + return false; + } + + @Override + public void checkCanCreateSchema(Identity identity, CatalogSchemaName schema) + { + } + + @Override + public void checkCanDropSchema(Identity identity, CatalogSchemaName schema) + { + } + + @Override + public void checkCanRenameSchema(Identity identity, CatalogSchemaName schema, String newSchemaName) + { + } + + @Override + public void checkCanShowSchemas(Identity identity, String catalogName) + { + } + + @Override + public Set filterSchemas(Identity identity, String catalogName, Set schemaNames) + { + if (!canAccessCatalog(identity, catalogName)) { + return ImmutableSet.of(); + } + + return schemaNames; + } + + @Override + public void checkCanCreateTable(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanDropTable(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanRenameTable(Identity identity, CatalogSchemaTableName table, CatalogSchemaTableName newTable) + { + } + + @Override + public void checkCanShowTablesMetadata(Identity identity, CatalogSchemaName schema) + { + } + + @Override + public Set filterTables(Identity identity, String catalogName, Set tableNames) + { + if (!canAccessCatalog(identity, catalogName)) { + return ImmutableSet.of(); + } + + return tableNames; + } + + @Override + public void checkCanAddColumn(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanRenameColumn(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanSelectFromTable(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanInsertIntoTable(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanDeleteFromTable(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanCreateView(Identity identity, CatalogSchemaTableName view) + { + } + + @Override + public void checkCanDropView(Identity identity, CatalogSchemaTableName view) + { + } + + @Override + public void checkCanSelectFromView(Identity identity, CatalogSchemaTableName view) + { + } + + @Override + public void checkCanCreateViewWithSelectFromTable(Identity identity, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanCreateViewWithSelectFromView(Identity identity, CatalogSchemaTableName view) + { + } + + @Override + public void checkCanSetCatalogSessionProperty(Identity identity, String catalogName, String propertyName) + { + } + + @Override + public void checkCanGrantTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table, String grantee, boolean withGrantOption) + { + } + + @Override + public void checkCanRevokeTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table, String revokee, boolean grantOptionFor) + { + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControlRules.java b/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControlRules.java new file mode 100644 index 0000000000000..5af6e1b25e675 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControlRules.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +public class FileBasedSystemAccessControlRules +{ + private final List catalogRules; + + @JsonCreator + public FileBasedSystemAccessControlRules(@JsonProperty("catalogs") Optional> catalogRules) + { + this.catalogRules = catalogRules.map(ImmutableList::copyOf).orElse(ImmutableList.of()); + } + + public List getCatalogRules() + { + return catalogRules; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java b/presto-main/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java index 6c83b156bdc45..387ddb0379696 100644 --- a/presto-main/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java +++ b/presto-main/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java @@ -62,6 +62,11 @@ public void checkCanSetSystemSessionProperty(Identity identity, String propertyN { } + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + } + @Override public void checkCanSelectFromTable(Identity identity, CatalogSchemaTableName table) { diff --git a/presto-main/src/main/java/com/facebook/presto/server/BasicQueryStats.java b/presto-main/src/main/java/com/facebook/presto/server/BasicQueryStats.java index 660ffe1e83cbf..8f8e3798d636b 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/BasicQueryStats.java +++ b/presto-main/src/main/java/com/facebook/presto/server/BasicQueryStats.java @@ -23,6 +23,7 @@ import javax.annotation.concurrent.Immutable; +import java.util.OptionalDouble; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; @@ -54,6 +55,8 @@ public class BasicQueryStats private final boolean fullyBlocked; private final Set blockedReasons; + private final OptionalDouble progressPercentage; + public BasicQueryStats( DateTime createTime, DateTime endTime, @@ -68,7 +71,8 @@ public BasicQueryStats( DataSize peakMemoryReservation, Duration totalCpuTime, boolean fullyBlocked, - Set blockedReasons) + Set blockedReasons, + OptionalDouble progressPercentage) { this.createTime = createTime; this.endTime = endTime; @@ -92,6 +96,8 @@ public BasicQueryStats( this.fullyBlocked = fullyBlocked; this.blockedReasons = ImmutableSet.copyOf(requireNonNull(blockedReasons, "blockedReasons is null")); + + this.progressPercentage = requireNonNull(progressPercentage, "progressPercentage is null"); } public BasicQueryStats(QueryStats queryStats) @@ -109,7 +115,8 @@ public BasicQueryStats(QueryStats queryStats) queryStats.getPeakMemoryReservation(), queryStats.getTotalCpuTime(), queryStats.isFullyBlocked(), - queryStats.getBlockedReasons()); + queryStats.getBlockedReasons(), + queryStats.getProgressPercentage()); } @JsonProperty @@ -195,4 +202,10 @@ public Set getBlockedReasons() { return blockedReasons; } + + @JsonProperty + public OptionalDouble getProgressPercentage() + { + return progressPercentage; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java b/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java index 1ec0c953508e6..c4bbd3a1d0ace 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java @@ -62,9 +62,9 @@ public ClusterStats getClusterStats() long runningDrivers = 0; double memoryReservation = 0; - double rowInputRate = 0; - double byteInputRate = 0; - double cpuTimeRate = 0; + long totalInputRows = queryManager.getStats().getConsumedInputRows().getTotalCount(); + long totalInputBytes = queryManager.getStats().getConsumedInputBytes().getTotalCount(); + long totalCpuTimeSecs = queryManager.getStats().getConsumedCpuTimeSecs().getTotalCount(); for (QueryInfo query : queryManager.getAllQueryInfo()) { if (query.getState() == QueryState.QUEUED) { @@ -80,18 +80,16 @@ else if (query.getState() == QueryState.RUNNING) { } if (!query.getState().isDone()) { - double totalExecutionTimeSeconds = query.getQueryStats().getElapsedTime().getValue(SECONDS); - if (totalExecutionTimeSeconds != 0) { - byteInputRate += query.getQueryStats().getProcessedInputDataSize().toBytes() / totalExecutionTimeSeconds; - rowInputRate += query.getQueryStats().getProcessedInputPositions() / totalExecutionTimeSeconds; - cpuTimeRate += (query.getQueryStats().getTotalCpuTime().getValue(SECONDS)) / totalExecutionTimeSeconds; - } + totalInputBytes += query.getQueryStats().getRawInputDataSize().toBytes(); + totalInputRows += query.getQueryStats().getRawInputPositions(); + totalCpuTimeSecs += query.getQueryStats().getTotalCpuTime().getValue(SECONDS); + memoryReservation += query.getQueryStats().getTotalMemoryReservation().toBytes(); runningDrivers += query.getQueryStats().getRunningDrivers(); } } - return new ClusterStats(runningQueries, blockedQueries, queuedQueries, activeNodes, runningDrivers, memoryReservation, rowInputRate, byteInputRate, cpuTimeRate); + return new ClusterStats(runningQueries, blockedQueries, queuedQueries, activeNodes, runningDrivers, memoryReservation, totalInputRows, totalInputBytes, totalCpuTimeSecs); } public static class ClusterStats @@ -104,9 +102,9 @@ public static class ClusterStats private final long runningDrivers; private final double reservedMemory; - private final double rowInputRate; - private final double byteInputRate; - private final double cpuTimeRate; + private final long totalInputRows; + private final long totalInputBytes; + private final long totalCpuTimeSecs; @JsonCreator public ClusterStats( @@ -116,9 +114,9 @@ public ClusterStats( @JsonProperty("activeWorkers") long activeWorkers, @JsonProperty("runningDrivers") long runningDrivers, @JsonProperty("reservedMemory") double reservedMemory, - @JsonProperty("rowInputRate") double rowInputRate, - @JsonProperty("byteInputRate") double byteInputRate, - @JsonProperty("cpuTimeRate") double cpuTimeRate) + @JsonProperty("totalInputRows") long totalInputRows, + @JsonProperty("totalInputBytes") long totalInputBytes, + @JsonProperty("totalCpuTimeSecs") long totalCpuTimeSecs) { this.runningQueries = runningQueries; this.blockedQueries = blockedQueries; @@ -126,9 +124,9 @@ public ClusterStats( this.activeWorkers = activeWorkers; this.runningDrivers = runningDrivers; this.reservedMemory = reservedMemory; - this.rowInputRate = rowInputRate; - this.byteInputRate = byteInputRate; - this.cpuTimeRate = cpuTimeRate; + this.totalInputRows = totalInputRows; + this.totalInputBytes = totalInputBytes; + this.totalCpuTimeSecs = totalCpuTimeSecs; } @JsonProperty @@ -168,21 +166,21 @@ public double getReservedMemory() } @JsonProperty - public double getRowInputRate() + public long getTotalInputRows() { - return rowInputRate; + return totalInputRows; } @JsonProperty - public double getByteInputRate() + public long getTotalInputBytes() { - return byteInputRate; + return totalInputBytes; } @JsonProperty - public double getCpuTimeRate() + public long getTotalCpuTimeSecs() { - return cpuTimeRate; + return totalCpuTimeSecs; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java b/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java index 53a4cc643e637..c59c381e68e92 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java @@ -96,6 +96,7 @@ import com.facebook.presto.sql.tree.ShowPartitions; import com.facebook.presto.sql.tree.ShowSchemas; import com.facebook.presto.sql.tree.ShowSession; +import com.facebook.presto.sql.tree.ShowStats; import com.facebook.presto.sql.tree.ShowTables; import com.facebook.presto.sql.tree.StartTransaction; import com.facebook.presto.sql.tree.Statement; @@ -107,6 +108,9 @@ import io.airlift.configuration.AbstractConfigurationAwareModule; import io.airlift.units.Duration; +import javax.annotation.PreDestroy; +import javax.inject.Inject; + import java.util.List; import java.util.concurrent.ExecutorService; @@ -120,6 +124,7 @@ import static io.airlift.http.server.HttpServerBinder.httpServerBinder; import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.TimeUnit.SECONDS; import static org.weakref.jmx.ObjectNames.generatedNameOf; @@ -149,6 +154,7 @@ protected void setup(Binder binder) jaxrsBinder(binder).bind(QueryResource.class); jaxrsBinder(binder).bind(StageResource.class); jaxrsBinder(binder).bind(QueryStateInfoResource.class); + jaxrsBinder(binder).bind(ResourceGroupStateInfoResource.class); binder.bind(QueryIdGenerator.class).in(Scopes.SINGLETON); binder.bind(QueryManager.class).to(SqlQueryManager.class).in(Scopes.SINGLETON); binder.bind(InternalResourceGroupManager.class).in(Scopes.SINGLETON); @@ -211,6 +217,7 @@ protected void setup(Binder binder) executionBinder.addBinding(Explain.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); executionBinder.addBinding(ShowCreate.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); executionBinder.addBinding(ShowColumns.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); + executionBinder.addBinding(ShowStats.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); executionBinder.addBinding(ShowPartitions.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); executionBinder.addBinding(ShowFunctions.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); executionBinder.addBinding(ShowTables.class).to(SqlQueryExecutionFactory.class).in(Scopes.SINGLETON); @@ -250,6 +257,9 @@ protected void setup(Binder binder) MapBinder executionPolicyBinder = newMapBinder(binder, String.class, ExecutionPolicy.class); executionPolicyBinder.addBinding("all-at-once").to(AllAtOnceExecutionPolicy.class); executionPolicyBinder.addBinding("phased").to(PhasedExecutionPolicy.class); + + // cleanup + binder.bind(ExecutorCleanup.class).in(Scopes.SINGLETON); } private static void bindDataDefinitionTask( @@ -264,4 +274,21 @@ private static void bindDataDefinitionTask( taskBinder.addBinding(statement).to(task).in(Scopes.SINGLETON); executionBinder.addBinding(statement).to(DataDefinitionExecutionFactory.class).in(Scopes.SINGLETON); } + + public static class ExecutorCleanup + { + private final ExecutorService executor; + + @Inject + public ExecutorCleanup(@ForQueryExecution ExecutorService executor) + { + this.executor = requireNonNull(executor, "executor is null"); + } + + @PreDestroy + public void shutdown() + { + executor.shutdownNow(); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java b/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java new file mode 100644 index 0000000000000..c471f2cd64454 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import io.airlift.configuration.Config; + +public class InternalCommunicationConfig +{ + private boolean httpsRequired; + private String keyStorePath; + private String keyStorePassword; + + public boolean isHttpsRequired() + { + return httpsRequired; + } + + @Config("internal-communication.https.required") + public InternalCommunicationConfig setHttpsRequired(boolean httpsRequired) + { + this.httpsRequired = httpsRequired; + return this; + } + + public String getKeyStorePath() + { + return keyStorePath; + } + + @Config("internal-communication.https.keystore.path") + public InternalCommunicationConfig setKeyStorePath(String keyStorePath) + { + this.keyStorePath = keyStorePath; + return this; + } + + public String getKeyStorePassword() + { + return keyStorePassword; + } + + @Config("internal-communication.https.keystore.key") + public InternalCommunicationConfig setKeyStorePassword(String keyStorePassword) + { + this.keyStorePassword = keyStorePassword; + return this; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfo.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfo.java new file mode 100644 index 0000000000000..0b3af0371c742 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfo.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupState; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ResourceGroupStateInfo +{ + private final ResourceGroupId id; + private final ResourceGroupState state; + + private final DataSize softMemoryLimit; + private final DataSize memoryUsage; + + private final int maxRunningQueries; + private final int maxQueuedQueries; + private final Duration runningTimeLimit; + private final Duration queuedTimeLimit; + private final List runningQueries; + private final int numQueuedQueries; + + @JsonCreator + public ResourceGroupStateInfo( + @JsonProperty("id") ResourceGroupId id, + @JsonProperty("state") ResourceGroupState state, + @JsonProperty("softMemoryLimit") DataSize softMemoryLimit, + @JsonProperty("memoryUsage") DataSize memoryUsage, + @JsonProperty("maxRunningQueries") int maxRunningQueries, + @JsonProperty("maxQueuedQueries") int maxQueuedQueries, + @JsonProperty("runningTimeLimit") Duration runningTimeLimit, + @JsonProperty("queuedTimeLimit") Duration queuedTimeLimit, + @JsonProperty("runningQueries") List runningQueries, + @JsonProperty("numQueuedQueries") int numQueuedQueries) + { + this.id = requireNonNull(id, "id is null"); + this.state = requireNonNull(state, "state is null"); + + this.softMemoryLimit = requireNonNull(softMemoryLimit, "softMemoryLimit is null"); + this.memoryUsage = requireNonNull(memoryUsage, "memoryUsage is null"); + + this.maxRunningQueries = maxRunningQueries; + this.maxQueuedQueries = maxQueuedQueries; + + this.runningTimeLimit = requireNonNull(runningTimeLimit, "runningTimeLimit is null"); + this.queuedTimeLimit = requireNonNull(queuedTimeLimit, "queuedTimeLimit is null"); + + this.runningQueries = ImmutableList.copyOf(requireNonNull(runningQueries, "runningQueries is null")); + this.numQueuedQueries = numQueuedQueries; + } + + @JsonProperty + public ResourceGroupId getId() + { + return id; + } + + @JsonProperty + public ResourceGroupState getState() + { + return state; + } + + @JsonProperty + public DataSize getSoftMemoryLimit() + { + return softMemoryLimit; + } + + @JsonProperty + public DataSize getMemoryUsage() + { + return memoryUsage; + } + + @JsonProperty + public int getMaxRunningQueries() + { + return maxRunningQueries; + } + + @JsonProperty + public int getMaxQueuedQueries() + { + return maxQueuedQueries; + } + + @JsonProperty + public Duration getQueuedTimeLimit() + { + return queuedTimeLimit; + } + + @JsonProperty + public Duration getRunningTimeLimit() + { + return runningTimeLimit; + } + + @JsonProperty + public List getRunningQueries() + { + return runningQueries; + } + + @JsonProperty + public int getNumQueuedQueries() + { + return numQueuedQueries; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java new file mode 100644 index 0000000000000..1b649d367deb4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; +import com.facebook.presto.spi.resourceGroups.ResourceGroupId; + +import javax.inject.Inject; +import javax.ws.rs.Encoded; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; + +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.Arrays; +import java.util.NoSuchElementException; + +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static javax.ws.rs.core.Response.Status.BAD_REQUEST; +import static javax.ws.rs.core.Response.Status.NOT_FOUND; + +@Path("/v1/resourceGroupState") +public class ResourceGroupStateInfoResource +{ + private final ResourceGroupManager resourceGroupManager; + + @Inject + public ResourceGroupStateInfoResource(ResourceGroupManager resourceGroupManager) + { + this.resourceGroupManager = requireNonNull(resourceGroupManager, "resourceGroupManager is null"); + } + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Encoded + @Path("{resourceGroupId: .+}") + public ResourceGroupStateInfo getQueryStateInfos(@PathParam("resourceGroupId") String resourceGroupIdString) + { + if (!isNullOrEmpty(resourceGroupIdString)) { + try { + return resourceGroupManager.getResourceGroupStateInfo( + new ResourceGroupId( + Arrays.stream(resourceGroupIdString.split("/")) + .map(ResourceGroupStateInfoResource::urlDecode) + .collect(toImmutableList()))); + } + catch (NoSuchElementException e) { + throw new WebApplicationException(NOT_FOUND); + } + } + throw new WebApplicationException(NOT_FOUND); + } + + private static String urlDecode(String value) + { + try { + return URLDecoder.decode(value, "UTF-8"); + } + catch (UnsupportedEncodingException e) { + throw new WebApplicationException(BAD_REQUEST); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java index 2e532faebd2ab..74efbfb91a086 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.server; +import com.facebook.presto.client.NodeVersion; import com.facebook.presto.client.ServerInfo; import com.facebook.presto.spi.NodeState; +import io.airlift.node.NodeInfo; import javax.inject.Inject; import javax.ws.rs.Consumes; @@ -26,8 +28,11 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import java.util.Optional; + import static com.facebook.presto.spi.NodeState.ACTIVE; import static com.facebook.presto.spi.NodeState.SHUTTING_DOWN; +import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static javax.ws.rs.core.MediaType.APPLICATION_JSON; @@ -37,13 +42,18 @@ @Path("/v1/info") public class ServerInfoResource { - private final ServerInfo serverInfo; + private final NodeVersion version; + private final String environment; + private final boolean coordinator; private final GracefulShutdownHandler shutdownHandler; + private final long startTime = System.nanoTime(); @Inject - public ServerInfoResource(ServerInfo serverInfo, GracefulShutdownHandler shutdownHandler) + public ServerInfoResource(NodeVersion nodeVersion, NodeInfo nodeInfo, ServerConfig serverConfig, GracefulShutdownHandler shutdownHandler) { - this.serverInfo = requireNonNull(serverInfo, "serverInfo is null"); + this.version = requireNonNull(nodeVersion, "nodeVersion is null"); + this.environment = requireNonNull(nodeInfo, "nodeInfo is null").getEnvironment(); + this.coordinator = requireNonNull(requireNonNull(serverConfig, "serverConfig is null").isCoordinator()); this.shutdownHandler = requireNonNull(shutdownHandler, "shutdownHandler is null"); } @@ -51,7 +61,7 @@ public ServerInfoResource(ServerInfo serverInfo, GracefulShutdownHandler shutdow @Produces(APPLICATION_JSON) public ServerInfo getServerInfo() { - return serverInfo; + return new ServerInfo(version, environment, coordinator, Optional.of(nanosSince(startTime))); } @PUT @@ -99,7 +109,7 @@ public NodeState getServerState() @Produces(TEXT_PLAIN) public Response getServerCoordinator() { - if (serverInfo.isCoordinator()) { + if (coordinator) { return Response.ok().build(); } // return 404 to allow load balancers to only send traffic to the coordinator diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index 46e31133bae21..bd88e4c5020bc 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -19,9 +19,10 @@ import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.block.BlockJsonSerde; import com.facebook.presto.client.NodeVersion; -import com.facebook.presto.client.ServerInfo; import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.connector.system.SystemConnectorModule; +import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.event.query.QueryMonitor; import com.facebook.presto.event.query.QueryMonitorConfig; import com.facebook.presto.execution.LocationFactory; @@ -121,6 +122,7 @@ import com.facebook.presto.type.TypeDeserializer; import com.facebook.presto.type.TypeRegistry; import com.facebook.presto.util.FinalizerService; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.inject.Binder; import com.google.inject.Provides; @@ -129,14 +131,17 @@ import io.airlift.concurrent.BoundedExecutor; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.airlift.discovery.client.ServiceDescriptor; -import io.airlift.node.NodeInfo; +import io.airlift.http.client.HttpClientConfig; import io.airlift.slice.Slice; import io.airlift.stats.PauseMeter; import io.airlift.units.DataSize; import io.airlift.units.Duration; +import javax.annotation.PreDestroy; +import javax.inject.Inject; import javax.inject.Singleton; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; @@ -194,6 +199,12 @@ protected void setup(Binder binder) })); } + InternalCommunicationConfig internalCommunicationConfig = buildConfigObject(InternalCommunicationConfig.class); + configBinder(binder).bindConfigGlobalDefaults(HttpClientConfig.class, config -> { + config.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); + config.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); + }); + configBinder(binder).bindConfig(FeaturesConfig.class); binder.bind(SqlParser.class).in(Scopes.SINGLETON); @@ -337,6 +348,9 @@ protected void setup(Binder binder) binder.bind(MetadataManager.class).in(Scopes.SINGLETON); binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); + // statistics calculator + binder.bind(CostCalculator.class).to(CoefficientBasedCostCalculator.class).in(Scopes.SINGLETON); + // type binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON); @@ -428,13 +442,9 @@ protected void setup(Binder binder) newExporter(binder).export(SpillerFactory.class).withGeneratedName(); binder.bind(LocalSpillManager.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(NodeSpillConfig.class); - } - @Provides - @Singleton - public static ServerInfo createServerInfo(NodeVersion nodeVersion, NodeInfo nodeInfo, ServerConfig serverConfig) - { - return new ServerInfo(nodeVersion, nodeInfo.getEnvironment(), serverConfig.isCoordinator()); + // cleanup + binder.bind(ExecutorCleanup.class).in(Scopes.SINGLETON); } @Provides @@ -521,4 +531,31 @@ public State getState(HostAddress hostAddress) }); } } + + public static class ExecutorCleanup + { + private final List executors; + + @Inject + public ExecutorCleanup( + @ForExchange ScheduledExecutorService exchangeExecutor, + @ForAsyncHttp ExecutorService httpResponseExecutor, + @ForAsyncHttp ScheduledExecutorService httpTimeoutExecutor, + @ForTransactionManager ExecutorService transactionFinishingExecutor, + @ForTransactionManager ScheduledExecutorService transactionIdleExecutor) + { + executors = ImmutableList.of( + exchangeExecutor, + httpResponseExecutor, + httpTimeoutExecutor, + transactionFinishingExecutor, + transactionIdleExecutor); + } + + @PreDestroy + public void shutdown() + { + executors.forEach(ExecutorService::shutdownNow); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java b/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java index 7f36b1e8446f0..684a77b642cf0 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/StatementResource.java @@ -534,11 +534,12 @@ private synchronized Iterable> getData(Duration maxWait) maxWait = new Duration(0, MILLISECONDS); } - if (bytes == 0) { + List rowIterables = pages.build(); + if (rowIterables.isEmpty()) { return null; } - return Iterables.concat(pages.build()); + return Iterables.concat(rowIterables); } private static boolean isQueryStarted(QueryInfo queryInfo) @@ -624,7 +625,7 @@ private static StatementStats toStatementStats(QueryInfo queryInfo) .setNodes(globalUniqueNodes(outputStage).size()) .setTotalSplits(queryStats.getTotalDrivers()) .setQueuedSplits(queryStats.getQueuedDrivers()) - .setRunningSplits(queryStats.getRunningDrivers()) + .setRunningSplits(queryStats.getRunningDrivers() + queryStats.getBlockedDrivers()) .setCompletedSplits(queryStats.getCompletedDrivers()) .setUserTimeMillis(queryStats.getTotalUserTime().toMillis()) .setCpuTimeMillis(queryStats.getTotalCpuTime().toMillis()) @@ -662,7 +663,7 @@ private static StageStats toStageStats(StageInfo stageInfo) .setNodes(uniqueNodes.size()) .setTotalSplits(stageStats.getTotalDrivers()) .setQueuedSplits(stageStats.getQueuedDrivers()) - .setRunningSplits(stageStats.getRunningDrivers()) + .setRunningSplits(stageStats.getRunningDrivers() + stageStats.getBlockedDrivers()) .setCompletedSplits(stageStats.getCompletedDrivers()) .setUserTimeMillis(stageStats.getTotalUserTime().toMillis()) .setCpuTimeMillis(stageStats.getTotalCpuTime().toMillis()) diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java index 5f5b56983db70..878fb511508a8 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.StageId; import com.facebook.presto.execution.TaskId; import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.QueryId; import io.airlift.http.server.HttpServerInfo; @@ -35,15 +36,15 @@ public class HttpLocationFactory private final URI baseUri; @Inject - public HttpLocationFactory(InternalNodeManager nodeManager, HttpServerInfo httpServerInfo) + public HttpLocationFactory(InternalNodeManager nodeManager, HttpServerInfo httpServerInfo, InternalCommunicationConfig config) { - this(nodeManager, httpServerInfo.getHttpUri()); + this(nodeManager, config.isHttpsRequired() ? httpServerInfo.getHttpsUri() : httpServerInfo.getHttpUri()); } public HttpLocationFactory(InternalNodeManager nodeManager, URI baseUri) { - this.nodeManager = nodeManager; - this.baseUri = baseUri; + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.baseUri = requireNonNull(baseUri, "baseUri is null"); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java index 254365b43d78e..e1ee1cafcc778 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java @@ -618,27 +618,30 @@ private void doScheduleAsyncCleanupRequest(Backoff cleanupBackoff, Request reque @Override public void onSuccess(JsonResponse result) { - updateTaskInfo(result.getValue()); + try { + updateTaskInfo(result.getValue()); + } + finally { + if (!getTaskInfo().getTaskStatus().getState().isDone()) { + cleanUpLocally(); + } + } } @Override public void onFailure(Throwable t) { if (t instanceof RejectedExecutionException) { - // client has been shutdown + // TODO: we should only give up retrying when the client has been shutdown + logError(t, "Unable to %s task at %s. Got RejectedExecutionException.", action, request.getUri()); + cleanUpLocally(); return; } // record failure if (cleanupBackoff.failure()) { - logError(t, "Unable to %s task at %s", action, request.getUri()); - // Update the taskInfo with the new taskStatus. - // This is required because the query state machine depends on TaskInfo (instead of task status) - // to transition its own state. - // Also, since this TaskInfo is updated in the client the "finalInfo" flag will not be set, - // indicating that the stats are stale. - // TODO: Update the query state machine and stage state machine to depend on TaskStatus instead - updateTaskInfo(getTaskInfo().withTaskStatus(getTaskStatus())); + logError(t, "Unable to %s task at %s. Back off depleted.", action, request.getUri()); + cleanUpLocally(); return; } @@ -651,6 +654,27 @@ public void onFailure(Throwable t) errorScheduledExecutor.schedule(() -> doScheduleAsyncCleanupRequest(cleanupBackoff, request, action), delayNanos, NANOSECONDS); } } + + private void cleanUpLocally() + { + // Update the taskInfo with the new taskStatus. + + // Generally, we send a cleanup request to the worker, and update the TaskInfo on + // the coordinator based on what we fetched from the worker. If we somehow cannot + // get the cleanup request to the worker, the TaskInfo that we fetch for the worker + // likely will not say the task is done however many times we try. In this case, + // we have to set the local query info directly so that we stop trying to fetch + // updated TaskInfo from the worker. This way, the task on the worker eventually + // expires due to lack of activity. + + // This is required because the query state machine depends on TaskInfo (instead of task status) + // to transition its own state. + // TODO: Update the query state machine and stage state machine to depend on TaskStatus instead + + // Since this TaskInfo is updated in the client the "complete" flag will not be set, + // indicating that the stats may not reflect the final stats on the worker. + updateTaskInfo(getTaskInfo().withTaskStatus(getTaskStatus())); + } }, executor); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/SimpleHttpResponseHandler.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/SimpleHttpResponseHandler.java index 62933e073849a..7f87ae067a6f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/SimpleHttpResponseHandler.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/SimpleHttpResponseHandler.java @@ -67,6 +67,9 @@ else if (response.getStatusCode() == HttpStatus.SERVICE_UNAVAILABLE.code()) { response.getResponseBody())); } } + else { + cause = new PrestoException(REMOTE_TASK_ERROR, format("Unexpected response from %s", uri), cause); + } callback.fatal(cause); } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/KerberosConfig.java b/presto-main/src/main/java/com/facebook/presto/server/security/KerberosConfig.java index 5b02299805e10..b12308b83ac45 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/KerberosConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/KerberosConfig.java @@ -15,6 +15,8 @@ import io.airlift.configuration.Config; +import javax.validation.constraints.NotNull; + import java.io.File; public class KerberosConfig @@ -23,6 +25,7 @@ public class KerberosConfig private String serviceName; private File keytab; + @NotNull public File getKerberosConfig() { return kerberosConfig; @@ -35,6 +38,7 @@ public KerberosConfig setKerberosConfig(File kerberosConfig) return this; } + @NotNull public String getServiceName() { return serviceName; diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 11c499610f26f..40909d6069a1a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -15,6 +15,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.eventlistener.EventListenerManager; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.execution.TaskManager; @@ -100,6 +101,7 @@ public class TestingPrestoServer private final CatalogManager catalogManager; private final TransactionManager transactionManager; private final Metadata metadata; + private final CostCalculator costCalculator; private final TestingAccessControlManager accessControl; private final ProcedureTester procedureTester; private final Optional resourceGroupManager; @@ -249,6 +251,7 @@ public TestingPrestoServer(boolean coordinator, catalogManager = injector.getInstance(CatalogManager.class); transactionManager = injector.getInstance(TransactionManager.class); metadata = injector.getInstance(Metadata.class); + costCalculator = injector.getInstance(CostCalculator.class); accessControl = injector.getInstance(TestingAccessControlManager.class); procedureTester = injector.getInstance(ProcedureTester.class); splitManager = injector.getInstance(SplitManager.class); @@ -346,6 +349,11 @@ public Metadata getMetadata() return metadata; } + public CostCalculator getCostCalculator() + { + return costCalculator; + } + public TestingAccessControlManager getAccessControl() { return accessControl; diff --git a/presto-main/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java b/presto-main/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java index 5ef585785feff..f65b81775e930 100644 --- a/presto-main/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java @@ -60,7 +60,7 @@ public class FileSingleStreamSpillerFactory private final PagesSerdeFactory serdeFactory; private final List spillPaths; private final SpillerStats spillerStats; - private final double minimumFreeSpaceThreshold; + private final double maxUsedSpaceThreshold; private int roundRobinIndex; @Inject @@ -101,7 +101,7 @@ public FileSingleStreamSpillerFactory( format("spill path %s is not writable; adjust experimental.spiller-spill-path config property or filesystem permissions", path)); } }); - this.minimumFreeSpaceThreshold = requireNonNull(maxUsedSpaceThreshold, "maxUsedSpaceThreshold can not be null"); + this.maxUsedSpaceThreshold = requireNonNull(maxUsedSpaceThreshold, "maxUsedSpaceThreshold can not be null"); this.roundRobinIndex = 0; } @@ -153,7 +153,7 @@ private boolean hasEnoughDiskSpace(Path path) { try { FileStore fileStore = getFileStore(path); - return fileStore.getUsableSpace() > fileStore.getTotalSpace() * (1.0 - minimumFreeSpaceThreshold); + return fileStore.getUsableSpace() > fileStore.getTotalSpace() * (1.0 - maxUsedSpaceThreshold); } catch (IOException e) { throw new PrestoException(OUT_OF_SPILL_SPACE, "Cannot determine free space for spill", e); diff --git a/presto-main/src/main/java/com/facebook/presto/connector/EmptySplitHandleResolver.java b/presto-main/src/main/java/com/facebook/presto/split/EmptySplitHandleResolver.java similarity index 94% rename from presto-main/src/main/java/com/facebook/presto/connector/EmptySplitHandleResolver.java rename to presto-main/src/main/java/com/facebook/presto/split/EmptySplitHandleResolver.java index db14829bcc699..610c05f2cc6df 100644 --- a/presto-main/src/main/java/com/facebook/presto/connector/EmptySplitHandleResolver.java +++ b/presto-main/src/main/java/com/facebook/presto/split/EmptySplitHandleResolver.java @@ -11,14 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.connector; +package com.facebook.presto.split; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; -import com.facebook.presto.split.EmptySplit; public class EmptySplitHandleResolver implements ConnectorHandleResolver diff --git a/presto-main/src/main/java/com/facebook/presto/split/EmptySplitPageSource.java b/presto-main/src/main/java/com/facebook/presto/split/EmptySplitPageSource.java new file mode 100644 index 0000000000000..c4a7c24d36036 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/split/EmptySplitPageSource.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.split; + +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.UpdatablePageSource; +import com.facebook.presto.spi.block.Block; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +import static java.util.concurrent.CompletableFuture.completedFuture; + +public class EmptySplitPageSource + implements UpdatablePageSource +{ + @Override + public void deleteRows(Block rowIds) + { + throw new UnsupportedOperationException("deleteRows called on EmptySplitPageSource"); + } + + @Override + public CompletableFuture> finish() + { + return completedFuture(ImmutableList.of()); + } + + @Override + public long getTotalBytes() + { + return 0; + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public boolean isFinished() + { + return true; + } + + @Override + public Page getNextPage() + { + return null; + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public void close() {} +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java index b5856d0d65d1c..4ecc694ec6a00 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -250,7 +250,7 @@ public static Function expressionOrNullSymbols(final Pre resultDisjunct.add(expression); for (Predicate nullSymbolScope : nullSymbolScopes) { - List symbols = DependencyExtractor.extractUnique(expression).stream() + List symbols = SymbolsExtractor.extractUnique(expression).stream() .filter(nullSymbolScope) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregateExtractor.java deleted file mode 100644 index 0566f1032df5d..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregateExtractor.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.analyzer; - -import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; -import com.facebook.presto.sql.tree.FunctionCall; -import com.google.common.collect.ImmutableList; - -import java.util.List; - -import static java.util.Objects.requireNonNull; - -class AggregateExtractor - extends DefaultExpressionTraversalVisitor -{ - private final FunctionRegistry functionRegistry; - - private final ImmutableList.Builder aggregates = ImmutableList.builder(); - - public AggregateExtractor(FunctionRegistry functionRegistry) - { - this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null"); - } - - @Override - protected Void visitFunctionCall(FunctionCall node, Void context) - { - if ((functionRegistry.isAggregationFunction(node.getName()) || node.getFilter().isPresent()) && !node.getWindow().isPresent()) { - aggregates.add(node); - return null; - } - - return super.visitFunctionCall(node, null); - } - - public List getAggregates() - { - return aggregates.build(); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java index c1e367782f883..5bf481a5e21a5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java @@ -33,6 +33,7 @@ import com.facebook.presto.sql.tree.Extract; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; @@ -44,6 +45,7 @@ import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.Parameter; @@ -67,16 +69,20 @@ import java.util.Set; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; -import static com.facebook.presto.sql.analyzer.LambdaReferenceExtractor.hasReferencesToLambdaArgument; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractWindowFunctions; +import static com.facebook.presto.sql.analyzer.FreeLambdaReferenceExtractor.hasFreeReferencesToLambdaArgument; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.getReferencesToScope; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.hasReferencesToScope; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.isFieldFromScope; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MUST_BE_AGGREGATE_OR_GROUP_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MUST_BE_AGGREGATION_FUNCTION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NESTED_AGGREGATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NESTED_WINDOW; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -91,7 +97,7 @@ class AggregationAnalyzer // fields and expressions in the group by clause private final Set groupingFields; private final List expressions; - private final Map columnReferences; + private final Map, FieldId> columnReferences; private final Metadata metadata; private final Analysis analysis; @@ -141,6 +147,7 @@ private AggregationAnalyzer(List groupByExpressions, Scope sourceSco this.columnReferences = analysis.getColumnReferenceFields(); this.groupingFields = groupByExpressions.stream() + .map(NodeRef::of) .filter(columnReferences::containsKey) .map(columnReferences::get) .collect(toImmutableSet()); @@ -307,28 +314,23 @@ protected Boolean visitFunctionCall(FunctionCall node, Void context) { if (metadata.isAggregationFunction(node.getName())) { if (!node.getWindow().isPresent()) { - AggregateExtractor aggregateExtractor = new AggregateExtractor(metadata.getFunctionRegistry()); - WindowFunctionExtractor windowExtractor = new WindowFunctionExtractor(); + List aggregateFunctions = extractAggregateFunctions(node.getArguments(), metadata.getFunctionRegistry()); + List windowFunctions = extractWindowFunctions(node.getArguments()); - for (Expression argument : node.getArguments()) { - aggregateExtractor.process(argument, null); - windowExtractor.process(argument, null); - } - - if (!aggregateExtractor.getAggregates().isEmpty()) { + if (!aggregateFunctions.isEmpty()) { throw new SemanticException(NESTED_AGGREGATION, node, "Cannot nest aggregations inside aggregation '%s': %s", node.getName(), - aggregateExtractor.getAggregates()); + aggregateFunctions); } - if (!windowExtractor.getWindowFunctions().isEmpty()) { + if (!windowFunctions.isEmpty()) { throw new SemanticException(NESTED_WINDOW, node, "Cannot nest window functions inside aggregation '%s': %s", node.getName(), - windowExtractor.getWindowFunctions()); + windowFunctions); } if (node.getFilter().isPresent() && node.isDistinct()) { @@ -340,7 +342,11 @@ protected Boolean visitFunctionCall(FunctionCall node, Void context) // ensure that no output fields are referenced from ORDER BY clause if (orderByScope.isPresent()) { - node.getArguments().stream().forEach(AggregationAnalyzer.this::verifyNoOrderByReferencesToOutputColumns); + node.getArguments().stream() + .forEach(argument -> verifyNoOrderByReferencesToOutputColumns( + argument, + REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION, + "Invalid reference to output projection attribute from ORDER BY aggregation")); } return true; @@ -369,7 +375,12 @@ protected Boolean visitLambdaExpression(LambdaExpression node, Void context) @Override protected Boolean visitBindExpression(BindExpression node, Void context) { - return process(node.getValue(), context) && process(node.getFunction(), context); + for (Expression value : node.getValues()) { + if (!process(value, context)) { + return false; + } + } + return process(node.getFunction(), context); } @Override @@ -423,7 +434,7 @@ public Boolean visitWindowFrame(WindowFrame node, Void context) @Override protected Boolean visitIdentifier(Identifier node, Void context) { - if (analysis.getLambdaArgumentReferences().containsKey(node)) { + if (analysis.getLambdaArgumentReferences().containsKey(NodeRef.of(node))) { return true; } return isGroupingKey(node); @@ -432,7 +443,7 @@ protected Boolean visitIdentifier(Identifier node, Void context) @Override protected Boolean visitDereferenceExpression(DereferenceExpression node, Void context) { - if (columnReferences.containsKey(node)) { + if (columnReferences.containsKey(NodeRef.of(node))) { return isGroupingKey(node); } @@ -442,7 +453,7 @@ protected Boolean visitDereferenceExpression(DereferenceExpression node, Void co private boolean isGroupingKey(Expression node) { - FieldId fieldId = columnReferences.get(node); + FieldId fieldId = columnReferences.get(NodeRef.of(node)); requireNonNull(fieldId, () -> "No FieldId for " + node); if (orderByScope.isPresent() && isFieldFromScope(fieldId, orderByScope.get())) { @@ -459,7 +470,7 @@ protected Boolean visitFieldReference(FieldReference node, Void context) return true; } - FieldId fieldId = requireNonNull(columnReferences.get(node), "No FieldId for FieldReference"); + FieldId fieldId = requireNonNull(columnReferences.get(NodeRef.of(node)), "No FieldId for FieldReference"); boolean inGroup = groupingFields.contains(fieldId); if (!inGroup) { Field field = sourceScope.getRelationType().getFieldByIndex(node.getFieldIndex()); @@ -568,12 +579,37 @@ public Boolean visitParameter(Parameter node, Void context) return process(parameters.get(node.getPosition()), context); } + public Boolean visitGroupingOperation(GroupingOperation node, Void context) + { + // ensure that no output fields are referenced from ORDER BY clause + if (orderByScope.isPresent()) { + node.getGroupingColumns().forEach(groupingColumn -> verifyNoOrderByReferencesToOutputColumns( + groupingColumn, + REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING, + "Invalid reference to output of SELECT clause from grouping() expression in ORDER BY" + )); + } + + Optional argumentNotInGroupBy = node.getGroupingColumns().stream() + .filter(argument -> !columnReferences.containsKey(NodeRef.of(argument)) || !isGroupingKey(argument)) + .findAny(); + if (argumentNotInGroupBy.isPresent()) { + throw new SemanticException( + INVALID_PROCEDURE_ARGUMENTS, + node, + "The arguments to GROUPING() must be expressions referenced by the GROUP BY at the associated query level. Mismatch due to %s.", + argumentNotInGroupBy.get() + ); + } + return true; + } + @Override public Boolean process(Node node, @Nullable Void context) { if (expressions.stream().anyMatch(node::equals) && (!orderByScope.isPresent() || !hasOrderByReferencesToOutputColumns(node)) - && !hasReferencesToLambdaArgument(node, analysis)) { + && !hasFreeReferencesToLambdaArgument(node, analysis)) { return true; } @@ -586,12 +622,12 @@ private boolean hasOrderByReferencesToOutputColumns(Node node) return hasReferencesToScope(node, analysis, orderByScope.get()); } - private void verifyNoOrderByReferencesToOutputColumns(Node node) + private void verifyNoOrderByReferencesToOutputColumns(Node node, SemanticErrorCode errorCode, String errorString) { getReferencesToScope(node, analysis, orderByScope.get()) .findFirst() .ifPresent(expression -> { - throw new SemanticException(REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION, expression, "Invalid reference to output projection attribute from ORDER BY aggregation"); + throw new SemanticException(errorCode, expression, errorString); }); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index bb19bcc62f8c8..ae83b6bec0fdb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -21,11 +21,13 @@ import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.Query; @@ -35,69 +37,77 @@ import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.Table; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; -import com.google.common.base.Preconditions; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ListMultimap; +import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; import java.util.ArrayDeque; import java.util.Collection; import java.util.Deque; -import java.util.IdentityHashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import static com.facebook.presto.util.MoreLists.listOfListsCopy; import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Collections.newSetFromMap; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableCollection; +import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; public class Analysis { + @Nullable private final Statement root; private final List parameters; private String updateType; - private final IdentityLinkedHashMap namedQueries = new IdentityLinkedHashMap<>(); + private final Map, Query> namedQueries = new LinkedHashMap<>(); - private final IdentityLinkedHashMap scopes = new IdentityLinkedHashMap<>(); - private final IdentityHashMap columnReferences = new IdentityHashMap<>(); + private final Map, Scope> scopes = new LinkedHashMap<>(); + private final Map, FieldId> columnReferences = new LinkedHashMap<>(); - private final IdentityLinkedHashMap> aggregates = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap> orderByAggregates = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap>> groupByExpressions = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap where = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap having = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap> orderByExpressions = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap> outputExpressions = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap> windowFunctions = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap> orderByWindowFunctions = new IdentityLinkedHashMap<>(); + private final Map, List> aggregates = new LinkedHashMap<>(); + private final Map, List> orderByAggregates = new LinkedHashMap<>(); + private final Map, List>> groupByExpressions = new LinkedHashMap<>(); + private final Map, Expression> where = new LinkedHashMap<>(); + private final Map, Expression> having = new LinkedHashMap<>(); + private final Map, List> orderByExpressions = new LinkedHashMap<>(); + private final Map, List> outputExpressions = new LinkedHashMap<>(); + private final Map, List> windowFunctions = new LinkedHashMap<>(); + private final Map, List> orderByWindowFunctions = new LinkedHashMap<>(); - private final IdentityLinkedHashMap joins = new IdentityLinkedHashMap<>(); - private final ListMultimap inPredicatesSubqueries = ArrayListMultimap.create(); - private final ListMultimap scalarSubqueries = ArrayListMultimap.create(); - private final ListMultimap existsSubqueries = ArrayListMultimap.create(); - private final ListMultimap quantifiedComparisonSubqueries = ArrayListMultimap.create(); + private final Map, Expression> joins = new LinkedHashMap<>(); + private final ListMultimap, InPredicate> inPredicatesSubqueries = ArrayListMultimap.create(); + private final ListMultimap, SubqueryExpression> scalarSubqueries = ArrayListMultimap.create(); + private final ListMultimap, ExistsPredicate> existsSubqueries = ArrayListMultimap.create(); + private final ListMultimap, QuantifiedComparisonExpression> quantifiedComparisonSubqueries = ArrayListMultimap.create(); - private final IdentityLinkedHashMap tables = new IdentityLinkedHashMap<>(); + private final Map, TableHandle> tables = new LinkedHashMap<>(); - private final IdentityLinkedHashMap types = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap coercions = new IdentityLinkedHashMap<>(); - private final Set typeOnlyCoercions = newSetFromMap(new IdentityLinkedHashMap<>()); - private final IdentityLinkedHashMap relationCoercions = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap functionSignature = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap lambdaArgumentReferences = new IdentityLinkedHashMap<>(); + private final Map, Type> types = new LinkedHashMap<>(); + private final Map, Type> coercions = new LinkedHashMap<>(); + private final Set> typeOnlyCoercions = new LinkedHashSet<>(); + private final Map, List> relationCoercions = new LinkedHashMap<>(); + private final Map, Signature> functionSignature = new LinkedHashMap<>(); + private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>(); - private final IdentityLinkedHashMap columns = new IdentityLinkedHashMap<>(); + private final Map columns = new LinkedHashMap<>(); - private final IdentityLinkedHashMap sampleRatios = new IdentityLinkedHashMap<>(); + private final Map, Double> sampleRatios = new LinkedHashMap<>(); + + private final Map, List> groupingOperations = new LinkedHashMap<>(); // for create table private Optional createTableDestination = Optional.empty(); @@ -114,12 +124,12 @@ public class Analysis // for recursive view detection private final Deque

tablesForView = new ArrayDeque<>(); - public Analysis(Statement root, List parameters, boolean isDescribe) + public Analysis(@Nullable Statement root, List parameters, boolean isDescribe) { requireNonNull(parameters); this.root = root; - this.parameters = parameters; + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); this.isDescribe = isDescribe; } @@ -160,200 +170,200 @@ public void setCreateTableAsSelectNoOp(boolean createTableAsSelectNoOp) public void setAggregates(QuerySpecification node, List aggregates) { - this.aggregates.put(node, aggregates); + this.aggregates.put(NodeRef.of(node), ImmutableList.copyOf(aggregates)); } public List getAggregates(QuerySpecification query) { - return aggregates.get(query); + return aggregates.get(NodeRef.of(query)); } public void setOrderByAggregates(OrderBy node, List aggregates) { - this.orderByAggregates.put(node, ImmutableList.copyOf(aggregates)); + this.orderByAggregates.put(NodeRef.of(node), ImmutableList.copyOf(aggregates)); } - public List getOrderByAggregates(OrderBy query) + public List getOrderByAggregates(OrderBy node) { - return orderByAggregates.get(query); + return orderByAggregates.get(NodeRef.of(node)); } - public IdentityLinkedHashMap getTypes() + public Map, Type> getTypes() { - return new IdentityLinkedHashMap<>(types); + return unmodifiableMap(types); } public Type getType(Expression expression) { - checkArgument(types.containsKey(expression), "Expression not analyzed: %s", expression); - return types.get(expression); + NodeRef key = NodeRef.of(expression); + checkArgument(types.containsKey(key), "Expression not analyzed: %s", expression); + return types.get(key); } public Type getTypeWithCoercions(Expression expression) { - checkArgument(types.containsKey(expression), "Expression not analyzed: %s", expression); - if (coercions.containsKey(expression)) { - return coercions.get(expression); + NodeRef key = NodeRef.of(expression); + checkArgument(types.containsKey(key), "Expression not analyzed: %s", expression); + if (coercions.containsKey(key)) { + return coercions.get(key); } - return types.get(expression); + return types.get(key); } public Type[] getRelationCoercion(Relation relation) { - return relationCoercions.get(relation); + return Optional.ofNullable(relationCoercions.get(NodeRef.of(relation))) + .map(types -> types.stream().toArray(Type[]::new)) + .orElse(null); } public void addRelationCoercion(Relation relation, Type[] types) { - relationCoercions.put(relation, types); + relationCoercions.put(NodeRef.of(relation), ImmutableList.copyOf(types)); } - public IdentityLinkedHashMap getCoercions() + public Map, Type> getCoercions() { - return coercions; + return unmodifiableMap(coercions); } public Type getCoercion(Expression expression) { - return coercions.get(expression); + return coercions.get(NodeRef.of(expression)); } - public void addLambdaArgumentReferences(IdentityLinkedHashMap lambdaArgumentReferences) + public void addLambdaArgumentReferences(Map, LambdaArgumentDeclaration> lambdaArgumentReferences) { this.lambdaArgumentReferences.putAll(lambdaArgumentReferences); } public LambdaArgumentDeclaration getLambdaArgumentReference(Identifier identifier) { - return lambdaArgumentReferences.get(identifier); + return lambdaArgumentReferences.get(NodeRef.of(identifier)); } - public IdentityLinkedHashMap getLambdaArgumentReferences() + public Map, LambdaArgumentDeclaration> getLambdaArgumentReferences() { - return lambdaArgumentReferences; + return unmodifiableMap(lambdaArgumentReferences); } public void setGroupingSets(QuerySpecification node, List> expressions) { - groupByExpressions.put(node, expressions); + groupByExpressions.put(NodeRef.of(node), listOfListsCopy(expressions)); } public boolean isTypeOnlyCoercion(Expression expression) { - return typeOnlyCoercions.contains(expression); + return typeOnlyCoercions.contains(NodeRef.of(expression)); } public List> getGroupingSets(QuerySpecification node) { - return groupByExpressions.get(node); + return groupByExpressions.get(NodeRef.of(node)); } public void setWhere(Node node, Expression expression) { - where.put(node, expression); + where.put(NodeRef.of(node), expression); } public Expression getWhere(QuerySpecification node) { - return where.get(node); + return where.get(NodeRef.of(node)); } public void setOrderByExpressions(Node node, List items) { - orderByExpressions.put(node, items); + orderByExpressions.put(NodeRef.of(node), ImmutableList.copyOf(items)); } public List getOrderByExpressions(Node node) { - return orderByExpressions.get(node); + return orderByExpressions.get(NodeRef.of(node)); } public void setOutputExpressions(Node node, List expressions) { - outputExpressions.put(node, expressions); + outputExpressions.put(NodeRef.of(node), ImmutableList.copyOf(expressions)); } public List getOutputExpressions(Node node) { - return outputExpressions.get(node); + return outputExpressions.get(NodeRef.of(node)); } public void setHaving(QuerySpecification node, Expression expression) { - having.put(node, expression); + having.put(NodeRef.of(node), expression); } public void setJoinCriteria(Join node, Expression criteria) { - joins.put(node, criteria); + joins.put(NodeRef.of(node), criteria); } public Expression getJoinCriteria(Join join) { - return joins.get(join); + return joins.get(NodeRef.of(join)); } public void recordSubqueries(Node node, ExpressionAnalysis expressionAnalysis) { - this.inPredicatesSubqueries.putAll(node, expressionAnalysis.getSubqueryInPredicates()); - this.scalarSubqueries.putAll(node, expressionAnalysis.getScalarSubqueries()); - this.existsSubqueries.putAll(node, expressionAnalysis.getExistsSubqueries()); - this.quantifiedComparisonSubqueries.putAll(node, expressionAnalysis.getQuantifiedComparisons()); + NodeRef key = NodeRef.of(node); + this.inPredicatesSubqueries.putAll(key, dereference(expressionAnalysis.getSubqueryInPredicates())); + this.scalarSubqueries.putAll(key, dereference(expressionAnalysis.getScalarSubqueries())); + this.existsSubqueries.putAll(key, dereference(expressionAnalysis.getExistsSubqueries())); + this.quantifiedComparisonSubqueries.putAll(key, dereference(expressionAnalysis.getQuantifiedComparisons())); + } + + private List dereference(Collection> nodeRefs) + { + return nodeRefs.stream() + .map(NodeRef::getNode) + .collect(toImmutableList()); } public List getInPredicateSubqueries(Node node) { - if (inPredicatesSubqueries.containsKey(node)) { - return inPredicatesSubqueries.get(node); - } - return ImmutableList.of(); + return ImmutableList.copyOf(inPredicatesSubqueries.get(NodeRef.of(node))); } public List getScalarSubqueries(Node node) { - if (scalarSubqueries.containsKey(node)) { - return scalarSubqueries.get(node); - } - return ImmutableList.of(); + return ImmutableList.copyOf(scalarSubqueries.get(NodeRef.of(node))); } public List getExistsSubqueries(Node node) { - if (existsSubqueries.containsKey(node)) { - return existsSubqueries.get(node); - } - return ImmutableList.of(); + return ImmutableList.copyOf(existsSubqueries.get(NodeRef.of(node))); } public List getQuantifiedComparisonSubqueries(Node node) { - if (quantifiedComparisonSubqueries.containsKey(node)) { - return quantifiedComparisonSubqueries.get(node); - } - return ImmutableList.of(); + return unmodifiableList(quantifiedComparisonSubqueries.get(NodeRef.of(node))); } public void setWindowFunctions(QuerySpecification node, List functions) { - windowFunctions.put(node, functions); + windowFunctions.put(NodeRef.of(node), ImmutableList.copyOf(functions)); } public List getWindowFunctions(QuerySpecification query) { - return windowFunctions.get(query); + return windowFunctions.get(NodeRef.of(query)); } public void setOrderByWindowFunctions(OrderBy node, List functions) { - orderByWindowFunctions.put(node, ImmutableList.copyOf(functions)); + orderByWindowFunctions.put(NodeRef.of(node), ImmutableList.copyOf(functions)); } public List getOrderByWindowFunctions(OrderBy query) { - return orderByWindowFunctions.get(query); + return orderByWindowFunctions.get(NodeRef.of(query)); } - public void addColumnReferences(IdentityLinkedHashMap columnReferences) + public void addColumnReferences(Map, FieldId> columnReferences) { this.columnReferences.putAll(columnReferences); } @@ -365,8 +375,9 @@ public Scope getScope(Node node) public Optional tryGetScope(Node node) { - if (scopes.containsKey(node)) { - return Optional.of(scopes.get(node)); + NodeRef key = NodeRef.of(node); + if (scopes.containsKey(key)) { + return Optional.of(scopes.get(key)); } return Optional.empty(); @@ -379,7 +390,7 @@ public Scope getRootScope() public void setScope(Node node, Scope scope) { - scopes.put(node, scope); + scopes.put(NodeRef.of(node), scope); } public RelationType getOutputDescriptor() @@ -394,53 +405,53 @@ public RelationType getOutputDescriptor(Node node) public TableHandle getTableHandle(Table table) { - return tables.get(table); + return tables.get(NodeRef.of(table)); } public Collection getTables() { - return tables.values(); + return unmodifiableCollection(tables.values()); } public void registerTable(Table table, TableHandle handle) { - tables.put(table, handle); + tables.put(NodeRef.of(table), handle); } public Signature getFunctionSignature(FunctionCall function) { - return functionSignature.get(function); + return functionSignature.get(NodeRef.of(function)); } - public void addFunctionSignatures(IdentityLinkedHashMap infos) + public void addFunctionSignatures(Map, Signature> infos) { functionSignature.putAll(infos); } - public Set getColumnReferences() + public Set> getColumnReferences() { return unmodifiableSet(columnReferences.keySet()); } - public Map getColumnReferenceFields() + public Map, FieldId> getColumnReferenceFields() { return unmodifiableMap(columnReferences); } - public void addTypes(IdentityLinkedHashMap types) + public void addTypes(Map, Type> types) { this.types.putAll(types); } public void addCoercion(Expression expression, Type type, boolean isTypeOnlyCoercion) { - this.coercions.put(expression, type); + this.coercions.put(NodeRef.of(expression), type); if (isTypeOnlyCoercion) { - this.typeOnlyCoercions.add(expression); + this.typeOnlyCoercions.add(NodeRef.of(expression)); } } - public void addCoercions(IdentityLinkedHashMap coercions, Set typeOnlyCoercions) + public void addCoercions(Map, Type> coercions, Set> typeOnlyCoercions) { this.coercions.putAll(coercions); this.typeOnlyCoercions.addAll(typeOnlyCoercions); @@ -448,7 +459,7 @@ public void addCoercions(IdentityLinkedHashMap coercions, Set< public Expression getHaving(QuerySpecification query) { - return having.get(query); + return having.get(NodeRef.of(query)); } public void setColumn(Field field, ColumnHandle handle) @@ -473,7 +484,7 @@ public Optional getCreateTableDestination() public void setCreateTableProperties(Map createTableProperties) { - this.createTableProperties = createTableProperties; + this.createTableProperties = ImmutableMap.copyOf(createTableProperties); } public Map getCreateTableProperties() @@ -503,7 +514,7 @@ public Optional getInsert() public Query getNamedQuery(Table table) { - return namedQueries.get(table); + return namedQueries.get(NodeRef.of(table)); } public void registerNamedQuery(Table tableReference, Query query) @@ -511,7 +522,7 @@ public void registerNamedQuery(Table tableReference, Query query) requireNonNull(tableReference, "tableReference is null"); requireNonNull(query, "query is null"); - namedQueries.put(tableReference, query); + namedQueries.put(NodeRef.of(tableReference), query); } public void registerTableForView(Table tableReference) @@ -531,13 +542,25 @@ public boolean hasTableInView(Table tableReference) public void setSampleRatio(SampledRelation relation, double ratio) { - sampleRatios.put(relation, ratio); + sampleRatios.put(NodeRef.of(relation), ratio); } public double getSampleRatio(SampledRelation relation) { - Preconditions.checkState(sampleRatios.containsKey(relation), "Sample ratio missing for %s. Broken analysis?", relation); - return sampleRatios.get(relation); + NodeRef key = NodeRef.of(relation); + checkState(sampleRatios.containsKey(key), "Sample ratio missing for %s. Broken analysis?", relation); + return sampleRatios.get(key); + } + + public void setGroupingOperations(QuerySpecification querySpecification, List groupingOperations) + { + this.groupingOperations.put(NodeRef.of(querySpecification), ImmutableList.copyOf(groupingOperations)); + } + + public List getGroupingOperations(QuerySpecification querySpecification) + { + return Optional.ofNullable(groupingOperations.get(NodeRef.of(querySpecification))) + .orElse(emptyList()); } public List getParameters() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java index 3c516021e3bd9..13ddabc652f2d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java @@ -21,6 +21,7 @@ import com.facebook.presto.sql.rewrite.StatementRewrite; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -28,7 +29,10 @@ import java.util.List; import java.util.Optional; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractExpressions; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractWindowFunctions; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING; import static java.util.Objects.requireNonNull; public class Analyzer @@ -69,18 +73,21 @@ public Analysis analyze(Statement statement, boolean isDescribe) return analysis; } - static void verifyNoAggregatesOrWindowFunctions(FunctionRegistry functionRegistry, Expression predicate, String clause) + static void verifyNoAggregateWindowOrGroupingFunctions(FunctionRegistry functionRegistry, Expression predicate, String clause) { - AggregateExtractor extractor = new AggregateExtractor(functionRegistry); - extractor.process(predicate, null); + List aggregates = extractAggregateFunctions(ImmutableList.of(predicate), functionRegistry); - WindowFunctionExtractor windowExtractor = new WindowFunctionExtractor(); - windowExtractor.process(predicate, null); + List windowExpressions = extractWindowFunctions(ImmutableList.of(predicate)); - List found = ImmutableList.copyOf(Iterables.concat(extractor.getAggregates(), windowExtractor.getWindowFunctions())); + List groupingOperations = extractExpressions(ImmutableList.of(predicate), GroupingOperation.class); + + List found = ImmutableList.copyOf(Iterables.concat( + aggregates, + windowExpressions, + groupingOperations)); if (!found.isEmpty()) { - throw new SemanticException(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, predicate, "%s cannot contain aggregations or window functions: %s", clause, found); + throw new SemanticException(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, predicate, "%s cannot contain aggregations, window functions or grouping operations: %s", clause, found); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java index 4d2cde264819f..983ca806ff221 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalysis.java @@ -19,95 +19,98 @@ import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.SubqueryExpression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.util.Map; import java.util.Set; import static java.util.Objects.requireNonNull; public class ExpressionAnalysis { - private final IdentityLinkedHashMap expressionTypes; - private final IdentityLinkedHashMap expressionCoercions; - private final Set typeOnlyCoercions; - private final IdentityLinkedHashMap columnReferences; - private final Set subqueryInPredicates; - private final Set scalarSubqueries; - private final Set existsSubqueries; - private final Set quantifiedComparisons; + private final Map, Type> expressionTypes; + private final Map, Type> expressionCoercions; + private final Set> typeOnlyCoercions; + private final Map, FieldId> columnReferences; + private final Set> subqueryInPredicates; + private final Set> scalarSubqueries; + private final Set> existsSubqueries; + private final Set> quantifiedComparisons; // For lambda argument references, maps each QualifiedNameReference to the referenced LambdaArgumentDeclaration - private final IdentityLinkedHashMap lambdaArgumentReferences; + private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences; public ExpressionAnalysis( - IdentityLinkedHashMap expressionTypes, - IdentityLinkedHashMap expressionCoercions, - Set subqueryInPredicates, - Set scalarSubqueries, - Set existsSubqueries, - IdentityLinkedHashMap columnReferences, - Set typeOnlyCoercions, - Set quantifiedComparisons, - IdentityLinkedHashMap lambdaArgumentReferences) + Map, Type> expressionTypes, + Map, Type> expressionCoercions, + Set> subqueryInPredicates, + Set> scalarSubqueries, + Set> existsSubqueries, + Map, FieldId> columnReferences, + Set> typeOnlyCoercions, + Set> quantifiedComparisons, + Map, LambdaArgumentDeclaration> lambdaArgumentReferences) { - this.expressionTypes = requireNonNull(expressionTypes, "expressionTypes is null"); - this.expressionCoercions = requireNonNull(expressionCoercions, "expressionCoercions is null"); - this.typeOnlyCoercions = requireNonNull(typeOnlyCoercions, "typeOnlyCoercions is null"); - this.columnReferences = new IdentityLinkedHashMap<>(requireNonNull(columnReferences, "columnReferences is null")); - this.subqueryInPredicates = requireNonNull(subqueryInPredicates, "subqueryInPredicates is null"); - this.scalarSubqueries = requireNonNull(scalarSubqueries, "subqueryInPredicates is null"); - this.existsSubqueries = requireNonNull(existsSubqueries, "existsSubqueries is null"); - this.quantifiedComparisons = requireNonNull(quantifiedComparisons, "quantifiedComparisons is null"); - this.lambdaArgumentReferences = requireNonNull(lambdaArgumentReferences, "lambdaArgumentReferences is null"); + this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null")); + this.expressionCoercions = ImmutableMap.copyOf(requireNonNull(expressionCoercions, "expressionCoercions is null")); + this.typeOnlyCoercions = ImmutableSet.copyOf(requireNonNull(typeOnlyCoercions, "typeOnlyCoercions is null")); + this.columnReferences = ImmutableMap.copyOf(requireNonNull(columnReferences, "columnReferences is null")); + this.subqueryInPredicates = ImmutableSet.copyOf(requireNonNull(subqueryInPredicates, "subqueryInPredicates is null")); + this.scalarSubqueries = ImmutableSet.copyOf(requireNonNull(scalarSubqueries, "subqueryInPredicates is null")); + this.existsSubqueries = ImmutableSet.copyOf(requireNonNull(existsSubqueries, "existsSubqueries is null")); + this.quantifiedComparisons = ImmutableSet.copyOf(requireNonNull(quantifiedComparisons, "quantifiedComparisons is null")); + this.lambdaArgumentReferences = ImmutableMap.copyOf(requireNonNull(lambdaArgumentReferences, "lambdaArgumentReferences is null")); } public Type getType(Expression expression) { - return expressionTypes.get(expression); + return expressionTypes.get(NodeRef.of(expression)); } - public IdentityLinkedHashMap getExpressionTypes() + public Map, Type> getExpressionTypes() { return expressionTypes; } public Type getCoercion(Expression expression) { - return expressionCoercions.get(expression); + return expressionCoercions.get(NodeRef.of(expression)); } public LambdaArgumentDeclaration getLambdaArgumentReference(Identifier qualifiedNameReference) { - return lambdaArgumentReferences.get(qualifiedNameReference); + return lambdaArgumentReferences.get(NodeRef.of(qualifiedNameReference)); } public boolean isTypeOnlyCoercion(Expression expression) { - return typeOnlyCoercions.contains(expression); + return typeOnlyCoercions.contains(NodeRef.of(expression)); } public boolean isColumnReference(Expression node) { - return columnReferences.containsKey(node); + return columnReferences.containsKey(NodeRef.of(node)); } - public Set getSubqueryInPredicates() + public Set> getSubqueryInPredicates() { return subqueryInPredicates; } - public Set getScalarSubqueries() + public Set> getScalarSubqueries() { return scalarSubqueries; } - public Set getExistsSubqueries() + public Set> getExistsSubqueries() { return existsSubqueries; } - public Set getQuantifiedComparisons() + public Set> getQuantifiedComparisons() { return quantifiedComparisons; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index f7104a4ecbd32..ba3cb4244e4fb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalParseResult; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; @@ -54,6 +55,7 @@ import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; @@ -67,6 +69,7 @@ import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullLiteral; @@ -88,8 +91,6 @@ import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.WindowFrame; import com.facebook.presto.type.FunctionType; -import com.facebook.presto.type.RowType; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.SliceUtf8; @@ -98,6 +99,8 @@ import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -105,6 +108,8 @@ import java.util.Set; import java.util.function.Function; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.MAX_NUMBER_GROUPING_ARGUMENTS_INTEGER; import static com.facebook.presto.spi.function.OperatorType.SUBSCRIPT; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -112,6 +117,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.RealType.REAL; +import static com.facebook.presto.spi.type.RowType.RowField; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.TimeType.TIME; import static com.facebook.presto.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; @@ -122,10 +128,11 @@ import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; -import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregatesOrWindowFunctions; +import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MULTIPLE_FIELDS_FROM_SUBQUERY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA; @@ -137,7 +144,6 @@ import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; import static com.facebook.presto.type.JsonType.JSON; -import static com.facebook.presto.type.RowType.RowField; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.DateTimeUtils.parseTimestampLiteral; import static com.facebook.presto.util.DateTimeUtils.timeHasTimeZone; @@ -147,7 +153,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; -import static java.util.Collections.newSetFromMap; +import static java.util.Collections.unmodifiableMap; +import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; public class ExpressionAnalyzer @@ -158,17 +165,17 @@ public class ExpressionAnalyzer private final Map symbolTypes; private final boolean isDescribe; - private final IdentityLinkedHashMap resolvedFunctions = new IdentityLinkedHashMap<>(); - private final Set scalarSubqueries = newSetFromMap(new IdentityLinkedHashMap<>()); - private final Set existsSubqueries = newSetFromMap(new IdentityLinkedHashMap<>()); - private final IdentityLinkedHashMap expressionCoercions = new IdentityLinkedHashMap<>(); - private final Set typeOnlyCoercions = newSetFromMap(new IdentityLinkedHashMap<>()); - private final Set subqueryInPredicates = newSetFromMap(new IdentityLinkedHashMap<>()); - private final IdentityLinkedHashMap columnReferences = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap expressionTypes = new IdentityLinkedHashMap<>(); - private final Set quantifiedComparisons = newSetFromMap(new IdentityLinkedHashMap<>()); + private final Map, Signature> resolvedFunctions = new LinkedHashMap<>(); + private final Set> scalarSubqueries = new LinkedHashSet<>(); + private final Set> existsSubqueries = new LinkedHashSet<>(); + private final Map, Type> expressionCoercions = new LinkedHashMap<>(); + private final Set> typeOnlyCoercions = new LinkedHashSet<>(); + private final Set> subqueryInPredicates = new LinkedHashSet<>(); + private final Map, FieldId> columnReferences = new LinkedHashMap<>(); + private final Map, Type> expressionTypes = new LinkedHashMap<>(); + private final Set> quantifiedComparisons = new LinkedHashSet<>(); // For lambda argument references, maps each QualifiedNameReference to the referenced LambdaArgumentDeclaration - private final IdentityLinkedHashMap lambdaArgumentReferences = new IdentityLinkedHashMap<>(); + private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>(); private final Session session; private final List parameters; @@ -191,14 +198,14 @@ public ExpressionAnalyzer( this.isDescribe = isDescribe; } - public IdentityLinkedHashMap getResolvedFunctions() + public Map, Signature> getResolvedFunctions() { - return resolvedFunctions; + return unmodifiableMap(resolvedFunctions); } - public IdentityLinkedHashMap getExpressionTypes() + public Map, Type> getExpressionTypes() { - return expressionTypes; + return unmodifiableMap(expressionTypes); } public Type setExpressionType(Expression expression, Type type) @@ -206,34 +213,43 @@ public Type setExpressionType(Expression expression, Type type) requireNonNull(expression, "expression cannot be null"); requireNonNull(type, "type cannot be null"); - expressionTypes.put(expression, type); + expressionTypes.put(NodeRef.of(expression), type); return type; } - public IdentityLinkedHashMap getExpressionCoercions() + private Type getExpressionType(Expression expression) { - return expressionCoercions; + requireNonNull(expression, "expression cannot be null"); + + Type type = expressionTypes.get(NodeRef.of(expression)); + checkState(type != null, "Expression not yet analyzed: %s", expression); + return type; + } + + public Map, Type> getExpressionCoercions() + { + return unmodifiableMap(expressionCoercions); } - public Set getTypeOnlyCoercions() + public Set> getTypeOnlyCoercions() { - return typeOnlyCoercions; + return unmodifiableSet(typeOnlyCoercions); } - public Set getSubqueryInPredicates() + public Set> getSubqueryInPredicates() { - return subqueryInPredicates; + return unmodifiableSet(subqueryInPredicates); } - public IdentityLinkedHashMap getColumnReferences() + public Map, FieldId> getColumnReferences() { - return columnReferences; + return unmodifiableMap(columnReferences); } - public IdentityLinkedHashMap getLambdaArgumentReferences() + public Map, LambdaArgumentDeclaration> getLambdaArgumentReferences() { - return lambdaArgumentReferences; + return unmodifiableMap(lambdaArgumentReferences); } public Type analyze(Expression expression, Scope scope) @@ -248,19 +264,19 @@ private Type analyze(Expression expression, Scope scope, Context context) return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context)); } - public Set getScalarSubqueries() + public Set> getScalarSubqueries() { - return scalarSubqueries; + return unmodifiableSet(scalarSubqueries); } - public Set getExistsSubqueries() + public Set> getExistsSubqueries() { - return existsSubqueries; + return unmodifiableSet(existsSubqueries); } - public Set getQuantifiedComparisons() + public Set> getQuantifiedComparisons() { - return quantifiedComparisons; + return unmodifiableSet(quantifiedComparisons); } private class Visitor @@ -278,7 +294,7 @@ private Visitor(Scope scope) public Type process(Node node, @Nullable StackableAstVisitorContext context) { // don't double process a node - Type type = expressionTypes.get(node); + Type type = expressionTypes.get(NodeRef.of(node)); if (type != null) { return type; } @@ -333,7 +349,7 @@ protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorCon if (context.getContext().isInLambda()) { LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getName()); if (lambdaArgumentDeclaration != null) { - Type result = expressionTypes.get(lambdaArgumentDeclaration); + Type result = getExpressionType(lambdaArgumentDeclaration); return setExpressionType(node, result); } } @@ -347,8 +363,8 @@ protected Type visitIdentifier(Identifier node, StackableAstVisitorContext argumentTypes = argumentTypesBuilder.build(); - Signature function; - try { - function = functionRegistry.resolveFunction(node.getName(), argumentTypes); - } - catch (PrestoException e) { - if (e.getErrorCode().getCode() == StandardErrorCode.FUNCTION_NOT_FOUND.toErrorCode().getCode()) { - throw new SemanticException(SemanticErrorCode.FUNCTION_NOT_FOUND, node, e.getMessage()); - } - if (e.getErrorCode().getCode() == StandardErrorCode.AMBIGUOUS_FUNCTION_CALL.toErrorCode().getCode()) { - throw new SemanticException(SemanticErrorCode.AMBIGUOUS_FUNCTION_CALL, node, e.getMessage()); - } - throw e; - } + Signature function = resolveFunction(node, argumentTypes, functionRegistry); for (int i = 0; i < node.getArguments().size(); i++) { Expression expression = node.getArguments().get(i); @@ -819,7 +823,7 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext context) { - verifyNoAggregatesOrWindowFunctions(functionRegistry, node.getBody(), "Lambda expression"); + verifyNoAggregateWindowOrGroupingFunctions(functionRegistry, node.getBody(), "Lambda expression"); if (!context.getContext().isExpectingLambda()) { throw new SemanticException(STANDALONE_LAMBDA, node, "Lambda expression should always be used inside a function"); } List types = context.getContext().getFunctionInputTypes(); List lambdaArguments = node.getArguments(); + + if (types.size() != lambdaArguments.size()) { + throw new SemanticException(INVALID_PARAMETER_USAGE, node, + format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size())); + } verify(types.size() == lambdaArguments.size()); Map nameToLambdaArgumentDeclarationMap = new HashMap<>(); @@ -1078,18 +1087,24 @@ protected Type visitBindExpression(BindExpression node, StackableAstVisitorConte { verify(context.getContext().isExpectingLambda(), "bind expression found when lambda is not expected"); - List functionInputTypes = ImmutableList.builder() - .add(process(node.getValue(), new StackableAstVisitorContext<>(context.getContext().notExpectingLambda()))) - .addAll(context.getContext().getFunctionInputTypes()) - .build(); + StackableAstVisitorContext innerContext = new StackableAstVisitorContext<>(context.getContext().notExpectingLambda()); + ImmutableList.Builder functionInputTypesBuilder = ImmutableList.builder(); + for (Expression value : node.getValues()) { + functionInputTypesBuilder.add(process(value, innerContext)); + } + functionInputTypesBuilder.addAll(context.getContext().getFunctionInputTypes()); + List functionInputTypes = functionInputTypesBuilder.build(); FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes))); List argumentTypes = functionType.getArgumentTypes(); + int numCapturedValues = node.getValues().size(); verify(argumentTypes.size() == functionInputTypes.size()); - verify(functionInputTypes.get(0) == argumentTypes.get(0)); + for (int i = 0; i < numCapturedValues; i++) { + verify(functionInputTypes.get(i) == argumentTypes.get(i)); + } - FunctionType result = new FunctionType(argumentTypes.subList(1, argumentTypes.size()), functionType.getReturnType()); + FunctionType result = new FunctionType(argumentTypes.subList(numCapturedValues, argumentTypes.size()), functionType.getReturnType()); return setExpressionType(node, result); } @@ -1105,6 +1120,24 @@ protected Type visitNode(Node node, StackableAstVisitorContext context) throw new SemanticException(NOT_SUPPORTED, node, "not yet implemented: " + node.getClass().getName()); } + public Type visitGroupingOperation(GroupingOperation node, StackableAstVisitorContext context) + { + if (node.getGroupingColumns().size() > MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT) { + throw new SemanticException(INVALID_PROCEDURE_ARGUMENTS, node, String.format("GROUPING supports up to %d column arguments", MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT)); + } + + for (Expression columnArgument : node.getGroupingColumns()) { + process(columnArgument, context); + } + + if (node.getGroupingColumns().size() <= MAX_NUMBER_GROUPING_ARGUMENTS_INTEGER) { + return setExpressionType(node, INTEGER); + } + else { + return setExpressionType(node, BIGINT); + } + } + private Type getOperator(StackableAstVisitorContext context, Expression node, OperatorType operatorType, Expression... arguments) { ImmutableList.Builder argumentTypes = ImmutableList.builder(); @@ -1209,12 +1242,13 @@ private Type coerceToSingleType(StackableAstVisitorContext context, Str private void addOrReplaceExpressionCoercion(Expression expression, Type type, Type superType) { - expressionCoercions.put(expression, superType); + NodeRef ref = NodeRef.of(expression); + expressionCoercions.put(ref, superType); if (typeManager.isTypeOnlyCoercion(type, superType)) { - typeOnlyCoercions.add(expression); + typeOnlyCoercions.add(ref); } - else if (typeOnlyCoercions.contains(expression)) { - typeOnlyCoercions.remove(expression); + else if (typeOnlyCoercions.contains(ref)) { + typeOnlyCoercions.remove(ref); } } } @@ -1281,7 +1315,23 @@ public List getFunctionInputTypes() } } - public static IdentityLinkedHashMap getExpressionTypes( + public static Signature resolveFunction(FunctionCall node, List argumentTypes, FunctionRegistry functionRegistry) + { + try { + return functionRegistry.resolveFunction(node.getName(), argumentTypes); + } + catch (PrestoException e) { + if (e.getErrorCode().getCode() == StandardErrorCode.FUNCTION_NOT_FOUND.toErrorCode().getCode()) { + throw new SemanticException(SemanticErrorCode.FUNCTION_NOT_FOUND, node, e.getMessage()); + } + if (e.getErrorCode().getCode() == StandardErrorCode.AMBIGUOUS_FUNCTION_CALL.toErrorCode().getCode()) { + throw new SemanticException(SemanticErrorCode.AMBIGUOUS_FUNCTION_CALL, node, e.getMessage()); + } + throw e; + } + } + + public static Map, Type> getExpressionTypes( Session session, Metadata metadata, SqlParser sqlParser, @@ -1292,7 +1342,7 @@ public static IdentityLinkedHashMap getExpressionTypes( return getExpressionTypes(session, metadata, sqlParser, types, expression, parameters, false); } - public static IdentityLinkedHashMap getExpressionTypes( + public static Map, Type> getExpressionTypes( Session session, Metadata metadata, SqlParser sqlParser, @@ -1304,7 +1354,7 @@ public static IdentityLinkedHashMap getExpressionTypes( return getExpressionTypes(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, isDescribe); } - public static IdentityLinkedHashMap getExpressionTypes( + public static Map, Type> getExpressionTypes( Session session, Metadata metadata, SqlParser sqlParser, @@ -1316,7 +1366,7 @@ public static IdentityLinkedHashMap getExpressionTypes( return analyzeExpressionsWithSymbols(session, metadata, sqlParser, types, expressions, parameters, isDescribe).getExpressionTypes(); } - public static IdentityLinkedHashMap getExpressionTypesFromInput( + public static Map, Type> getExpressionTypesFromInput( Session session, Metadata metadata, SqlParser sqlParser, @@ -1327,7 +1377,7 @@ public static IdentityLinkedHashMap getExpressionTypesFromInpu return getExpressionTypesFromInput(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters); } - public static IdentityLinkedHashMap getExpressionTypesFromInput( + public static Map, Type> getExpressionTypesFromInput( Session session, Metadata metadata, SqlParser sqlParser, @@ -1421,10 +1471,10 @@ public static ExpressionAnalysis analyzeExpression( ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, accessControl, ImmutableMap.of()); analyzer.analyze(expression, scope); - IdentityLinkedHashMap expressionTypes = analyzer.getExpressionTypes(); - IdentityLinkedHashMap expressionCoercions = analyzer.getExpressionCoercions(); - Set typeOnlyCoercions = analyzer.getTypeOnlyCoercions(); - IdentityLinkedHashMap resolvedFunctions = analyzer.getResolvedFunctions(); + Map, Type> expressionTypes = analyzer.getExpressionTypes(); + Map, Type> expressionCoercions = analyzer.getExpressionCoercions(); + Set> typeOnlyCoercions = analyzer.getTypeOnlyCoercions(); + Map, Signature> resolvedFunctions = analyzer.getResolvedFunctions(); analysis.addTypes(expressionTypes); analysis.addCoercions(expressionCoercions, typeOnlyCoercions); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java new file mode 100644 index 0000000000000..2da9a95ac3163 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.Node; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.function.Predicate; + +import static com.google.common.base.Predicates.alwaysTrue; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +final class ExpressionTreeUtils +{ + private ExpressionTreeUtils() {} + + static List extractAggregateFunctions(Iterable nodes, FunctionRegistry functionRegistry) + { + return extractExpressions(nodes, FunctionCall.class, isAggregationPredicate(functionRegistry)); + } + + static List extractWindowFunctions(Iterable nodes) + { + return extractExpressions(nodes, FunctionCall.class, ExpressionTreeUtils::isWindowFunction); + } + + static List extractExpressions( + Iterable nodes, + Class clazz) + { + return extractExpressions(nodes, clazz, alwaysTrue()); + } + + private static Predicate isAggregationPredicate(FunctionRegistry functionRegistry) + { + return ((functionCall) -> (functionRegistry.isAggregationFunction(functionCall.getName()) || functionCall.getFilter().isPresent()) && !functionCall.getWindow().isPresent()); + } + + private static boolean isWindowFunction(FunctionCall functionCall) + { + return functionCall.getWindow().isPresent(); + } + + private static List extractExpressions( + Iterable nodes, + Class clazz, + Predicate predicate) + { + requireNonNull(nodes, "nodes is null"); + requireNonNull(clazz, "clazz is null"); + requireNonNull(predicate, "predicate is null"); + + return ImmutableList.copyOf(nodes).stream() + .flatMap(node -> linearizeNodes(node).stream()) + .filter(clazz::isInstance) + .map(clazz::cast) + .filter(predicate) + .collect(toImmutableList()); + } + + private static List linearizeNodes(Node node) + { + ImmutableList.Builder nodes = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor() + { + @Override + public Node process(Node node, Void context) + { + Node result = super.process(node, context); + nodes.add(node); + return result; + } + }.process(node, null); + return nodes.build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index ea9306e5611e2..8d8fb01bf6eef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -43,7 +43,7 @@ public class FeaturesConfig private boolean distributedJoinsEnabled = true; private boolean colocatedJoinsEnabled; private boolean fastInequalityJoins = true; - private boolean reorderJoins; + private boolean reorderJoins = true; private boolean redistributeWrites = true; private boolean optimizeMetadataQueries; private boolean optimizeHashGeneration = true; @@ -54,6 +54,7 @@ public class FeaturesConfig private boolean legacyArrayAgg; private boolean legacyOrderBy; private boolean legacyMapSubscript; + private boolean newMapBlock = true; private boolean optimizeMixedDistinctAggregations; private boolean dictionaryAggregation; @@ -68,6 +69,7 @@ public class FeaturesConfig private int spillerThreads = 4; private double spillMaxUsedSpaceThreshold = 0.9; private boolean iterativeOptimizerEnabled = true; + private boolean pushAggregationThroughJoin = true; private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata @@ -136,6 +138,18 @@ public boolean isLegacyMapSubscript() return legacyMapSubscript; } + @Config("deprecated.new-map-block") + public FeaturesConfig setNewMapBlock(boolean value) + { + this.newMapBlock = value; + return this; + } + + public boolean isNewMapBlock() + { + return newMapBlock; + } + @Config("distributed-joins-enabled") public FeaturesConfig setDistributedJoinsEnabled(boolean distributedJoinsEnabled) { @@ -157,7 +171,7 @@ public FeaturesConfig setColocatedJoinsEnabled(boolean colocatedJoinsEnabled) } @Config("fast-inequality-joins") - @ConfigDescription("Experimental: Use faster handling of inequality joins if it is possible") + @ConfigDescription("Use faster handling of inequality joins if it is possible") public FeaturesConfig setFastInequalityJoins(boolean fastInequalityJoins) { this.fastInequalityJoins = fastInequalityJoins; @@ -412,4 +426,16 @@ public FeaturesConfig setEnableIntermediateAggregations(boolean enableIntermedia this.enableIntermediateAggregations = enableIntermediateAggregations; return this; } + + public boolean isPushAggregationThroughJoin() + { + return pushAggregationThroughJoin; + } + + @Config("optimizer.push-aggregation-through-join") + public FeaturesConfig setPushAggregationThroughJoin(boolean value) + { + this.pushAggregationThroughJoin = value; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FreeLambdaReferenceExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FreeLambdaReferenceExtractor.java new file mode 100644 index 0000000000000..fa758c8edbf3b --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FreeLambdaReferenceExtractor.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; +import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +/** + * Extract expressions that are free (unbound) references to a lambda argument. + */ +public class FreeLambdaReferenceExtractor +{ + private FreeLambdaReferenceExtractor() {} + + public static boolean hasFreeReferencesToLambdaArgument(Node node, Analysis analysis) + { + return !getFreeReferencesToLambdaArgument(node, analysis).isEmpty(); + } + + public static List getFreeReferencesToLambdaArgument(Node node, Analysis analysis) + { + Visitor visitor = new Visitor(analysis); + visitor.process(node, ImmutableSet.of()); + return visitor.getFreeReferencesToLambdaArgument(); + } + + private static class Visitor + extends DefaultExpressionTraversalVisitor> + { + private final Analysis analysis; + private final ImmutableList.Builder freeReferencesToLambdaArgument = ImmutableList.builder(); + + private Visitor(Analysis analysis) + { + this.analysis = requireNonNull(analysis, "analysis is null"); + } + + List getFreeReferencesToLambdaArgument() + { + return freeReferencesToLambdaArgument.build(); + } + + @Override + protected Void visitIdentifier(Identifier node, Set lambdaArgumentNames) + { + if (analysis.getLambdaArgumentReferences().containsKey(NodeRef.of(node)) && !lambdaArgumentNames.contains(node.getName())) { + freeReferencesToLambdaArgument.add(node); + } + return null; + } + + @Override + protected Void visitLambdaExpression(LambdaExpression node, Set lambdaArgumentNames) + { + return process(node.getBody(), ImmutableSet.builder() + .addAll(lambdaArgumentNames) + .addAll(node.getArguments().stream() + .map(LambdaArgumentDeclaration::getName) + .collect(toImmutableSet())) + .build()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/LambdaReferenceExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/LambdaReferenceExtractor.java deleted file mode 100644 index ff9a0106fd4a1..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/LambdaReferenceExtractor.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.analyzer; - -import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.Identifier; -import com.facebook.presto.sql.tree.Node; -import com.google.common.collect.ImmutableList; - -import java.util.List; - -import static java.util.Objects.requireNonNull; - -/** - * Extract expressions that are references to a lambda argument. - */ -public class LambdaReferenceExtractor -{ - private LambdaReferenceExtractor() {} - - public static boolean hasReferencesToLambdaArgument(Node node, Analysis analysis) - { - return !getReferencesToLambdaArgument(node, analysis).isEmpty(); - } - - public static List getReferencesToLambdaArgument(Node node, Analysis analysis) - { - ImmutableList.Builder builder = ImmutableList.builder(); - new Visitor(analysis).process(node, builder); - return builder.build(); - } - - private static class Visitor - extends DefaultExpressionTraversalVisitor> - { - private final Analysis analysis; - - private Visitor(Analysis analysis) - { - this.analysis = requireNonNull(analysis, "analysis is null"); - } - - @Override - protected Void visitIdentifier(Identifier node, ImmutableList.Builder context) - { - if (analysis.getLambdaArgumentReferences().containsKey(node)) { - context.add(node); - } - return null; - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java index ae9d51e0f0000..28236a161f95e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.DataDefinitionTask; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -45,6 +46,7 @@ public class QueryExplainer private final Metadata metadata; private final AccessControl accessControl; private final SqlParser sqlParser; + private final CostCalculator costCalculator; private final Map, DataDefinitionTask> dataDefinitionTask; @Inject @@ -53,12 +55,14 @@ public QueryExplainer( Metadata metadata, AccessControl accessControl, SqlParser sqlParser, + CostCalculator costCalculator, Map, DataDefinitionTask> dataDefinitionTask) { this(planOptimizers.get(), metadata, accessControl, sqlParser, + costCalculator, dataDefinitionTask); } @@ -67,12 +71,14 @@ public QueryExplainer( Metadata metadata, AccessControl accessControl, SqlParser sqlParser, + CostCalculator costCalculator, Map, DataDefinitionTask> dataDefinitionTask) { this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.dataDefinitionTask = ImmutableMap.copyOf(requireNonNull(dataDefinitionTask, "dataDefinitionTask is null")); } @@ -92,10 +98,10 @@ public String getPlan(Session session, Statement statement, Type planType, List< switch (planType) { case LOGICAL: Plan plan = getLogicalPlan(session, statement, parameters); - return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, session); + return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, costCalculator, session); case DISTRIBUTED: SubPlan subPlan = getDistributedPlan(session, statement, parameters); - return PlanPrinter.textDistributedPlan(subPlan, metadata, session); + return PlanPrinter.textDistributedPlan(subPlan, metadata, costCalculator, session); } throw new IllegalArgumentException("Unhandled plan type: " + planType); } @@ -124,7 +130,7 @@ public String getGraphvizPlan(Session session, Statement statement, Type planTyp throw new IllegalArgumentException("Unhandled plan type: " + planType); } - private Plan getLogicalPlan(Session session, Statement statement, List parameters) + public Plan getLogicalPlan(Session session, Statement statement, List parameters) { // analyze statement Analysis analysis = analyze(session, statement, parameters); @@ -132,7 +138,7 @@ private Plan getLogicalPlan(Session session, Statement statement, List getReferencesToScope(Node node, Analysis analysis, Scope scope) { - Map columnReferences = analysis.getColumnReferenceFields(); + Map, FieldId> columnReferences = analysis.getColumnReferenceFields(); return AstUtils.preOrder(node) - .filter(columnReferences::containsKey) + .filter(Expression.class::isInstance) .map(Expression.class::cast) + .filter(expression -> columnReferences.containsKey(NodeRef.of(expression))) .filter(expression -> isReferenceToScope(expression, scope, columnReferences)); } - private static boolean isReferenceToScope(Expression node, Scope scope, Map columnReferences) + private static boolean isReferenceToScope(Expression node, Scope scope, Map, FieldId> columnReferences) { - FieldId fieldId = columnReferences.get(node); + FieldId fieldId = columnReferences.get(NodeRef.of(node)); requireNonNull(fieldId, () -> "No FieldId for " + node); return isFieldFromScope(fieldId, scope); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java index 3f0f16e8ac066..7b01411b9d1e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java @@ -50,10 +50,11 @@ public enum SemanticErrorCode FUNCTION_NOT_FOUND, ORDER_BY_MUST_BE_IN_SELECT, + REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING, REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION, NONDETERMINISTIC_ORDER_BY_EXPRESSION_WITH_SELECT_DISTINCT, - CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, + CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, WILDCARD_WITHOUT_FROM, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index f0f9a2e591b64..1efaee7edc683 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -29,13 +29,16 @@ import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.ExpressionInterpreter; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.tree.AddColumn; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.AllColumns; @@ -65,6 +68,7 @@ import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.Grant; import com.facebook.presto.sql.tree.GroupingElement; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Insert; import com.facebook.presto.sql.tree.Intersect; @@ -72,9 +76,11 @@ import com.facebook.presto.sql.tree.JoinCriteria; import com.facebook.presto.sql.tree.JoinOn; import com.facebook.presto.sql.tree.JoinUsing; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Prepare; import com.facebook.presto.sql.tree.QualifiedName; @@ -107,10 +113,6 @@ import com.facebook.presto.sql.tree.With; import com.facebook.presto.sql.tree.WithQuery; import com.facebook.presto.sql.util.AstUtils; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -143,6 +145,9 @@ import static com.facebook.presto.sql.analyzer.AggregationAnalyzer.verifySourceAggregations; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractExpressions; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractWindowFunctions; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.hasReferencesToScope; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.AMBIGUOUS_ATTRIBUTE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.COLUMN_NAME_NOT_SPECIFIED; @@ -150,6 +155,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_COLUMN_ALIASES; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_SET_COLUMN_TYPES; @@ -189,6 +195,7 @@ import static com.google.common.collect.Iterables.transform; import static java.lang.Math.toIntExact; import static java.util.Collections.emptyList; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; class StatementAnalyzer @@ -640,6 +647,14 @@ else if (expressionType instanceof MapType) { return createAndAssignScope(node, scope, outputFields.build()); } + @Override + protected Scope visitLateral(Lateral node, Optional scope) + { + StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session); + Scope queryScope = analyzer.analyze(node.getQuery(), scope); + return createAndAssignScope(node, scope, queryScope.getRelationType()); + } + @Override protected Scope visitTable(Table table, Optional scope) { @@ -798,11 +813,11 @@ protected Scope visitAliasedRelation(AliasedRelation relation, Optional s @Override protected Scope visitSampledRelation(SampledRelation relation, Optional scope) { - if (!DependencyExtractor.extractNames(relation.getSamplePercentage(), analysis.getColumnReferences()).isEmpty()) { + if (!SymbolsExtractor.extractNames(relation.getSamplePercentage(), analysis.getColumnReferences()).isEmpty()) { throw new SemanticException(NON_NUMERIC_SAMPLE_PERCENTAGE, relation.getSamplePercentage(), "Sample percentage cannot contain column references"); } - IdentityLinkedHashMap expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, @@ -876,6 +891,7 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional sourceExpressions.addAll(outputExpressions); node.getHaving().ifPresent(sourceExpressions::add); + analyzeGroupingOperations(node, sourceExpressions, orderByExpressions); List aggregations = analyzeAggregations(node, sourceScope, orderByScope, groupByExpressions, sourceExpressions, orderByExpressions); analyzeWindowFunctions(node, outputExpressions, orderByExpressions); @@ -885,7 +901,7 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional // Original ORDER BY scope "sees" FROM query fields. However, during planning // and when aggregation is present, ORDER BY expressions should only be resolvable against // output scope, group by expressions and aggregation expressions. - computeAndAssignOrderByScopeWithAggregation(node.getOrderBy().get(), outputScope, aggregations, groupByExpressions); + computeAndAssignOrderByScopeWithAggregation(node.getOrderBy().get(), outputScope, aggregations, groupByExpressions, analysis.getGroupingOperations(node)); } return outputScope; @@ -920,6 +936,7 @@ private Scope legacyVisitQuerySpecification(QuerySpecification node, Optional scope) int outputFieldSize = outputFieldTypes.length; RelationType relationType = relationScope.getRelationType(); int descFieldSize = relationType.getVisibleFields().size(); - String setOperationName = node.getClass().getSimpleName(); + String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH); if (outputFieldSize != descFieldSize) { throw new SemanticException(MISMATCHED_SET_COLUMN_TYPES, node, @@ -959,7 +976,7 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) throw new SemanticException(TYPE_MISMATCH, node, "column %d in %s query has incompatible types: %s, %s", - i, outputFieldTypes[i].getDisplayName(), setOperationName, descFieldType.getDisplayName()); + i, setOperationName, outputFieldTypes[i].getDisplayName(), descFieldType.getDisplayName()); } outputFieldTypes[i] = commonSuperType.get(); } @@ -1023,7 +1040,7 @@ protected Scope visitJoin(Join node, Optional scope) } Scope left = process(node.getLeft(), scope); - Scope right = process(node.getRight(), isUnnestRelation(node.getRight()) ? Optional.of(left) : scope); + Scope right = process(node.getRight(), isLateralRelation(node.getRight()) ? Optional.of(left) : scope); Scope output = createAndAssignScope(node, scope, left.getRelationType().joinWith(right.getRelationType())); @@ -1067,7 +1084,7 @@ else if (criteria instanceof JoinOn) { analysis.addCoercion(expression, BOOLEAN, false); } - Analyzer.verifyNoAggregatesOrWindowFunctions(metadata.getFunctionRegistry(), expression, "JOIN clause"); + Analyzer.verifyNoAggregateWindowOrGroupingFunctions(metadata.getFunctionRegistry(), expression, "JOIN clause"); analysis.recordSubqueries(node, expressionAnalysis); analysis.setJoinCriteria(node, expression); @@ -1079,12 +1096,12 @@ else if (criteria instanceof JoinOn) { return output; } - private boolean isUnnestRelation(Relation node) + private boolean isLateralRelation(Relation node) { if (node instanceof AliasedRelation) { - return isUnnestRelation(((AliasedRelation) node).getRelation()); + return isLateralRelation(((AliasedRelation) node).getRelation()); } - return node instanceof Unnest; + return node instanceof Unnest || node instanceof Lateral; } private void addCoercionForJoinCriteria(Join node, Expression leftExpression, Expression rightExpression) @@ -1185,14 +1202,11 @@ private void analyzeWindowFunctions(QuerySpecification node, List ou private List analyzeWindowFunctions(QuerySpecification node, List expressions) { - WindowFunctionExtractor extractor = new WindowFunctionExtractor(); - for (Expression expression : expressions) { - extractor.process(expression, null); new WindowFunctionValidator().process(expression, analysis); } - List windowFunctions = extractor.getWindowFunctions(); + List windowFunctions = extractWindowFunctions(expressions); for (FunctionCall windowFunction : windowFunctions) { // filter with window function is not supported yet @@ -1202,24 +1216,15 @@ private List analyzeWindowFunctions(QuerySpecification node, List< Window window = windowFunction.getWindow().get(); - WindowFunctionExtractor nestedExtractor = new WindowFunctionExtractor(); - for (Expression argument : windowFunction.getArguments()) { - nestedExtractor.process(argument, null); - } - - for (Expression expression : window.getPartitionBy()) { - nestedExtractor.process(expression, null); - } + ImmutableList.Builder toExtract = ImmutableList.builder(); + toExtract.addAll(windowFunction.getArguments()); + toExtract.addAll(window.getPartitionBy()); + window.getOrderBy().ifPresent(orderBy -> toExtract.addAll(orderBy.getSortItems())); + window.getFrame().ifPresent(toExtract::add); - if (window.getOrderBy().isPresent()) { - nestedExtractor.process(window.getOrderBy().get(), null); - } + List nestedWindowFunctions = extractWindowFunctions(toExtract.build()); - if (window.getFrame().isPresent()) { - nestedExtractor.process(window.getFrame().get(), null); - } - - if (!nestedExtractor.getWindowFunctions().isEmpty()) { + if (!nestedWindowFunctions.isEmpty()) { throw new SemanticException(NESTED_WINDOW, node, "Cannot nest window functions inside window function '%s': %s", windowFunction, windowFunctions); @@ -1503,7 +1508,7 @@ private List analyzeGroupingColumns(Set groupingColumns, groupByExpression = groupingColumn; } - Analyzer.verifyNoAggregatesOrWindowFunctions(metadata.getFunctionRegistry(), groupByExpression, "GROUP BY clause"); + Analyzer.verifyNoAggregateWindowOrGroupingFunctions(metadata.getFunctionRegistry(), groupByExpression, "GROUP BY clause"); Type type = analysis.getType(groupByExpression); if (!type.isComparable()) { throw new SemanticException(TYPE_MISMATCH, node, "%s is not comparable, and therefore cannot be used in GROUP BY", type); @@ -1577,7 +1582,7 @@ private Scope computeAndAssignOrderByScope(OrderBy node, Scope sourceScope, Scop return orderByScope; } - private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope outputScope, List aggregations, List> groupByExpressions) + private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope outputScope, List aggregations, List> groupByExpressions, List groupingOperations) { // This scope is only used for planning. When aggregation is present then // only output fields, groups and aggregation expressions should be visible from ORDER BY expression @@ -1586,6 +1591,7 @@ private Scope computeAndAssignOrderByScopeWithAggregation(OrderBy node, Scope ou .flatMap(List::stream) .forEach(orderByAggregationExpressionsBuilder::add); orderByAggregationExpressionsBuilder.addAll(aggregations); + orderByAggregationExpressionsBuilder.addAll(groupingOperations); // Don't add aggregate expression that contains references to output column because the names would clash in TranslationMap during planning. List orderByExpressionsReferencingOutputScope = AstUtils.preOrder(node) @@ -1670,7 +1676,7 @@ else if (item instanceof SingleColumn) { public void analyzeWhere(Node node, Scope scope, Expression predicate) { - Analyzer.verifyNoAggregatesOrWindowFunctions(metadata.getFunctionRegistry(), predicate, "WHERE clause"); + Analyzer.verifyNoAggregateWindowOrGroupingFunctions(metadata.getFunctionRegistry(), predicate, "WHERE clause"); ExpressionAnalysis expressionAnalysis = analyzeExpression(predicate, scope); analysis.recordSubqueries(node, expressionAnalysis); @@ -1696,6 +1702,21 @@ private Scope analyzeFrom(QuerySpecification node, Optional scope) return createScope(scope); } + private void analyzeGroupingOperations(QuerySpecification node, List outputExpressions, List orderByExpressions) + { + List groupingOperations = extractExpressions(Iterables.concat(outputExpressions, orderByExpressions), GroupingOperation.class); + boolean isGroupingOperationPresent = !groupingOperations.isEmpty(); + + if (isGroupingOperationPresent && !node.getGroupBy().isPresent()) { + throw new SemanticException( + INVALID_PROCEDURE_ARGUMENTS, + node, + "A GROUPING() operation can only be used with a corresponding GROUPING SET/CUBE/ROLLUP/GROUP BY clause"); + } + + analysis.setGroupingOperations(node, groupingOperations); + } + private List analyzeAggregations( QuerySpecification node, Scope sourceScope, @@ -1706,11 +1727,8 @@ private List analyzeAggregations( { checkState(orderByExpressions.isEmpty() || orderByScope.isPresent(), "non-empty orderByExpressions list without orderByScope provided"); - AggregateExtractor extractor = new AggregateExtractor(metadata.getFunctionRegistry()); - for (Expression expression : Iterables.concat(outputExpressions, orderByExpressions)) { - extractor.process(expression); - } - analysis.setAggregates(node, extractor.getAggregates()); + List aggregates = extractAggregateFunctions(Iterables.concat(outputExpressions, orderByExpressions), metadata.getFunctionRegistry()); + analysis.setAggregates(node, aggregates); // is this an aggregation query? if (!groupingSets.isEmpty()) { @@ -1733,24 +1751,24 @@ private List analyzeAggregations( } } - return extractor.getAggregates(); + return aggregates; } private boolean hasAggregates(QuerySpecification node) { - AggregateExtractor extractor = new AggregateExtractor(metadata.getFunctionRegistry()); + ImmutableList.Builder toExtractBuilder = ImmutableList.builder(); - node.getSelect() - .getSelectItems().stream() + toExtractBuilder.addAll(node.getSelect().getSelectItems().stream() .filter(SingleColumn.class::isInstance) - .forEach(extractor::process); + .collect(toImmutableList())); + + toExtractBuilder.addAll(getSortItemsFromOrderBy(node.getOrderBy())); - getSortItemsFromOrderBy(node.getOrderBy()).forEach(extractor::process); + node.getHaving().ifPresent(toExtractBuilder::add); - node.getHaving() - .ifPresent(extractor::process); + List aggregates = extractAggregateFunctions(toExtractBuilder.build(), metadata.getFunctionRegistry()); - return !extractor.getAggregates().isEmpty(); + return !aggregates.isEmpty(); } private RelationType analyzeView(Query query, QualifiedObjectName name, Optional catalog, Optional schema, Optional owner, Table node) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionExtractor.java deleted file mode 100644 index 9ba25f3bcb5b0..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionExtractor.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.analyzer; - -import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; -import com.facebook.presto.sql.tree.FunctionCall; -import com.google.common.collect.ImmutableList; - -import java.util.List; - -class WindowFunctionExtractor - extends DefaultExpressionTraversalVisitor -{ - private final ImmutableList.Builder windowFunctions = ImmutableList.builder(); - - @Override - protected Void visitFunctionCall(FunctionCall node, Void context) - { - if (node.getWindow().isPresent()) { - windowFunctions.add(node); - return null; - } - - return super.visitFunctionCall(node, null); - } - - public List getWindowFunctions() - { - return windowFunctions.build(); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java index 0f30cd1add683..662457516c146 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BindCodeGenerator.java @@ -14,63 +14,46 @@ package com.facebook.presto.sql.gen; -import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; -import com.facebook.presto.bytecode.Scope; -import com.facebook.presto.bytecode.Variable; -import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; +import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Primitives; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; import java.util.List; +import java.util.Map; -import static com.facebook.presto.bytecode.ParameterizedType.type; -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; -import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; +import static com.google.common.base.Preconditions.checkState; public class BindCodeGenerator implements BytecodeGenerator { + private Map compiledLambdaMap; + private Class lambdaInterface; + + public BindCodeGenerator(Map compiledLambdaMap, Class lambdaInterface) + { + this.compiledLambdaMap = compiledLambdaMap; + this.lambdaInterface = lambdaInterface; + } + @Override public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type returnType, List arguments) { - BytecodeBlock block = new BytecodeBlock().setDescription("Partial apply"); - Scope scope = context.getScope(); - - Variable wasNull = scope.getVariable("wasNull"); - - Class valueType = Primitives.wrap(arguments.get(0).getType().getJavaType()); - Variable valueVariable = scope.createTempVariable(valueType); - block.append(context.generate(arguments.get(0))); - block.append(boxPrimitiveIfNecessary(scope, valueType)); - block.putVariable(valueVariable); - block.append(wasNull.set(constantFalse())); - - Variable functionVariable = scope.createTempVariable(MethodHandle.class); - block.append(context.generate(arguments.get(1))); - block.append( - new IfStatement() - .condition(wasNull) - // ifTrue: do nothing i.e. Leave the null MethodHandle on the stack, and leave the wasNull variable set to true - .ifFalse( - new BytecodeBlock() - .putVariable(functionVariable) - .append(invokeStatic( - MethodHandles.class, - "insertArguments", - MethodHandle.class, - functionVariable, - constantInt(0), - newArray(type(Object[].class), ImmutableList.of(valueVariable.cast(Object.class))))))); - - return block; + // Bind expression is used to generate captured lambda. + // It takes the captured values and the uncaptured lambda, and produces captured lambda as the output. + // The uncaptured lambda is just a method, and does not have a stack representation during execution. + // As a result, the bind expression generates the captured lambda in one step. + int numCaptures = arguments.size() - 1; + LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) arguments.get(numCaptures); + checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition"); + CompiledLambda compiledLambda = compiledLambdaMap.get(lambda); + + return LambdaBytecodeGenerator.generateLambda( + context, + arguments.subList(0, numCaptures), + compiledLambda, + lambdaInterface); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BodyCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BodyCompiler.java index 22e7fb679dbed..4d101195dc5bf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BodyCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BodyCompiler.java @@ -18,7 +18,7 @@ import java.util.List; -public interface BodyCompiler +public interface BodyCompiler { void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List projections); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeExpressionVisitor.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeExpressionVisitor.java deleted file mode 100644 index 65301c444d868..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeExpressionVisitor.java +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.gen; - -import com.facebook.presto.bytecode.BytecodeBlock; -import com.facebook.presto.bytecode.BytecodeNode; -import com.facebook.presto.bytecode.Scope; -import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.sql.relational.CallExpression; -import com.facebook.presto.sql.relational.ConstantExpression; -import com.facebook.presto.sql.relational.InputReferenceExpression; -import com.facebook.presto.sql.relational.LambdaDefinitionExpression; -import com.facebook.presto.sql.relational.RowExpressionVisitor; -import com.facebook.presto.sql.relational.VariableReferenceExpression; - -import java.lang.invoke.MethodHandle; - -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; -import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; -import static com.facebook.presto.bytecode.instruction.Constant.loadBoolean; -import static com.facebook.presto.bytecode.instruction.Constant.loadDouble; -import static com.facebook.presto.bytecode.instruction.Constant.loadFloat; -import static com.facebook.presto.bytecode.instruction.Constant.loadInt; -import static com.facebook.presto.bytecode.instruction.Constant.loadLong; -import static com.facebook.presto.bytecode.instruction.Constant.loadString; -import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; -import static com.facebook.presto.sql.relational.Signatures.BIND; -import static com.facebook.presto.sql.relational.Signatures.CAST; -import static com.facebook.presto.sql.relational.Signatures.COALESCE; -import static com.facebook.presto.sql.relational.Signatures.DEREFERENCE; -import static com.facebook.presto.sql.relational.Signatures.IF; -import static com.facebook.presto.sql.relational.Signatures.IN; -import static com.facebook.presto.sql.relational.Signatures.IS_NULL; -import static com.facebook.presto.sql.relational.Signatures.NULL_IF; -import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR; -import static com.facebook.presto.sql.relational.Signatures.SWITCH; -import static com.facebook.presto.sql.relational.Signatures.TRY; -import static com.google.common.base.Preconditions.checkState; - -public class BytecodeExpressionVisitor - implements RowExpressionVisitor -{ - private final CallSiteBinder callSiteBinder; - private final CachedInstanceBinder cachedInstanceBinder; - private final RowExpressionVisitor fieldReferenceCompiler; - private final FunctionRegistry registry; - private final PreGeneratedExpressions preGeneratedExpressions; - - public BytecodeExpressionVisitor( - CallSiteBinder callSiteBinder, - CachedInstanceBinder cachedInstanceBinder, - RowExpressionVisitor fieldReferenceCompiler, - FunctionRegistry registry, - PreGeneratedExpressions preGeneratedExpressions) - { - this.callSiteBinder = callSiteBinder; - this.cachedInstanceBinder = cachedInstanceBinder; - this.fieldReferenceCompiler = fieldReferenceCompiler; - this.registry = registry; - this.preGeneratedExpressions = preGeneratedExpressions; - } - - @Override - public BytecodeNode visitCall(CallExpression call, final Scope scope) - { - BytecodeGenerator generator; - // special-cased in function registry - if (call.getSignature().getName().equals(CAST)) { - generator = new CastCodeGenerator(); - } - else { - switch (call.getSignature().getName()) { - // lazy evaluation - case IF: - generator = new IfCodeGenerator(); - break; - case NULL_IF: - generator = new NullIfCodeGenerator(); - break; - case SWITCH: - // (SWITCH (WHEN ) (WHEN ) ) - generator = new SwitchCodeGenerator(); - break; - case TRY: - generator = new TryCodeGenerator(preGeneratedExpressions.getTryMethodMap()); - break; - // functions that take null as input - case IS_NULL: - generator = new IsNullCodeGenerator(); - break; - case COALESCE: - generator = new CoalesceCodeGenerator(); - break; - // functions that require varargs and/or complex types (e.g., lists) - case IN: - generator = new InCodeGenerator(registry); - break; - // optimized implementations (shortcircuiting behavior) - case "AND": - generator = new AndCodeGenerator(); - break; - case "OR": - generator = new OrCodeGenerator(); - break; - case DEREFERENCE: - generator = new DereferenceCodeGenerator(); - break; - case ROW_CONSTRUCTOR: - generator = new RowConstructorCodeGenerator(); - break; - case BIND: - generator = new BindCodeGenerator(); - break; - default: - generator = new FunctionCallCodeGenerator(); - } - } - - BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext( - this, - scope, - callSiteBinder, - cachedInstanceBinder, - registry); - - return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments()); - } - - @Override - public BytecodeNode visitConstant(ConstantExpression constant, Scope scope) - { - Object value = constant.getValue(); - Class javaType = constant.getType().getJavaType(); - - BytecodeBlock block = new BytecodeBlock(); - if (value == null) { - return block.comment("constant null") - .append(scope.getVariable("wasNull").set(constantTrue())) - .pushJavaDefault(javaType); - } - - // use LDC for primitives (boolean, short, int, long, float, double) - block.comment("constant " + constant.getType().getTypeSignature()); - if (javaType == boolean.class) { - return block.append(loadBoolean((Boolean) value)); - } - if (javaType == byte.class || javaType == short.class || javaType == int.class) { - return block.append(loadInt(((Number) value).intValue())); - } - if (javaType == long.class) { - return block.append(loadLong((Long) value)); - } - if (javaType == float.class) { - return block.append(loadFloat((Float) value)); - } - if (javaType == double.class) { - return block.append(loadDouble((Double) value)); - } - if (javaType == String.class) { - return block.append(loadString((String) value)); - } - if (javaType == void.class) { - return block; - } - - // bind constant object directly into the call-site using invoke dynamic - Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType()); - - return new BytecodeBlock() - .setDescription("constant " + constant.getType()) - .comment(constant.toString()) - .append(loadConstant(binding)); - } - - @Override - public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) - { - return fieldReferenceCompiler.visitInputReference(node, scope); - } - - @Override - public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope scope) - { - checkState(preGeneratedExpressions.getLambdaFieldMap().containsKey(lambda), "lambda expressions map does not contain this lambda definition"); - - return getStatic(preGeneratedExpressions.getLambdaFieldMap().get(lambda)) - .invoke("bindTo", MethodHandle.class, scope.getThis().cast(Object.class)) - .invoke("bindTo", MethodHandle.class, scope.getVariable("session").cast(Object.class)); - } - - @Override - public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope scope) - { - return fieldReferenceCompiler.visitVariableReference(reference, scope); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java index d3af4960223f2..4abd084e2f3dc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeGeneratorContext.java @@ -29,31 +29,34 @@ public class BytecodeGeneratorContext { - private final BytecodeExpressionVisitor bytecodeGenerator; + private final RowExpressionCompiler rowExpressionCompiler; private final Scope scope; private final CallSiteBinder callSiteBinder; private final CachedInstanceBinder cachedInstanceBinder; private final FunctionRegistry registry; + private final PreGeneratedExpressions preGeneratedExpressions; private final Variable wasNull; public BytecodeGeneratorContext( - BytecodeExpressionVisitor bytecodeGenerator, + RowExpressionCompiler rowExpressionCompiler, Scope scope, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, - FunctionRegistry registry) + FunctionRegistry registry, + PreGeneratedExpressions preGeneratedExpressions) { - requireNonNull(bytecodeGenerator, "bytecodeGenerator is null"); + requireNonNull(rowExpressionCompiler, "bytecodeGenerator is null"); requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null"); requireNonNull(scope, "scope is null"); requireNonNull(callSiteBinder, "callSiteBinder is null"); requireNonNull(registry, "registry is null"); - this.bytecodeGenerator = bytecodeGenerator; + this.rowExpressionCompiler = rowExpressionCompiler; this.scope = scope; this.callSiteBinder = callSiteBinder; this.cachedInstanceBinder = cachedInstanceBinder; this.registry = registry; + this.preGeneratedExpressions = preGeneratedExpressions; this.wasNull = scope.getVariable("wasNull"); } @@ -69,7 +72,12 @@ public CallSiteBinder getCallSiteBinder() public BytecodeNode generate(RowExpression expression) { - return expression.accept(bytecodeGenerator, scope); + return generate(expression, Optional.empty()); + } + + public BytecodeNode generate(RowExpression expression, Optional lambdaInterface) + { + return rowExpressionCompiler.compile(expression, scope, lambdaInterface); } public FunctionRegistry getRegistry() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java index 0a2bf8d345507..3c25631105e75 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java @@ -16,7 +16,6 @@ import com.facebook.presto.bytecode.BytecodeBlock; import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.ClassDefinition; -import com.facebook.presto.bytecode.FieldDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.bytecode.Scope; @@ -25,12 +24,12 @@ import com.facebook.presto.bytecode.control.IfStatement; import com.facebook.presto.bytecode.instruction.LabelNode; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.operator.CursorProcessor; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.ConstantExpression; import com.facebook.presto.sql.relational.InputReferenceExpression; @@ -46,6 +45,7 @@ import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; +import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -61,7 +61,7 @@ import static java.lang.String.format; public class CursorProcessorCompiler - implements BodyCompiler + implements BodyCompiler { private final Metadata metadata; @@ -74,11 +74,19 @@ public CursorProcessorCompiler(Metadata metadata) public void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List projections) { CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); + List allPreGeneratedExpressions = new ArrayList<>(projections.size() + 1); + generateProcessMethod(classDefinition, projections.size()); - generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter); + + PreGeneratedExpressions filterPreGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter, "filter"); + allPreGeneratedExpressions.add(filterPreGeneratedExpressions); + generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filterPreGeneratedExpressions, filter); for (int i = 0; i < projections.size(); i++) { - generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, "project_" + i, projections.get(i)); + String methodName = "project_" + i; + PreGeneratedExpressions projectPreGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projections.get(i), methodName); + allPreGeneratedExpressions.add(projectPreGeneratedExpressions); + generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, projectPreGeneratedExpressions, methodName, projections.get(i)); } MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); @@ -87,7 +95,13 @@ public void generateMethods(ClassDefinition classDefinition, CallSiteBinder call constructorBody.comment("super();") .append(thisVariable) .invokeConstructor(Object.class); + cachedInstanceBinder.generateInitializations(thisVariable, constructorBody); + for (PreGeneratedExpressions preGeneratedExpressions : allPreGeneratedExpressions) { + for (CompiledLambda compiledLambda : preGeneratedExpressions.getCompiledLambdaMap().values()) { + compiledLambda.generateInitialization(thisVariable, constructorBody); + } + } constructorBody.ret(); } @@ -192,7 +206,7 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( Set lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(projection)); ImmutableMap.Builder tryMethodMap = ImmutableMap.builder(); - ImmutableMap.Builder lambdaFieldMap = ImmutableMap.builder(); + ImmutableMap.Builder compiledLambdaMap = ImmutableMap.builder(); int counter = 0; for (RowExpression expression : lambdaAndTryExpressions) { @@ -208,15 +222,15 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( .add(cursor) .build(); - BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor( + RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(cursor), metadata.getFunctionRegistry(), - new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build())); + new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build())); MethodDefinition tryMethod = defineTryMethod( - innerExpressionVisitor, + innerExpressionCompiler, containerClassDefinition, methodPrefix + "_try_" + counter, inputParameters, @@ -229,8 +243,8 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( else if (expression instanceof LambdaDefinitionExpression) { LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression; String fieldName = methodPrefix + "_lambda_" + counter; - PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); - FieldDefinition methodHandleField = LambdaBytecodeGenerator.preGenerateLambdaExpression( + PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build()); + CompiledLambda compiledLambda = LambdaBytecodeGenerator.preGenerateLambdaExpression( lambdaExpression, fieldName, containerClassDefinition, @@ -238,7 +252,7 @@ else if (expression instanceof LambdaDefinitionExpression) { callSiteBinder, cachedInstanceBinder, metadata.getFunctionRegistry()); - lambdaFieldMap.put(lambdaExpression, methodHandleField); + compiledLambdaMap.put(lambdaExpression, compiledLambda); } else { throw new VerifyException(format("unexpected expression: %s", expression.toString())); @@ -246,13 +260,16 @@ else if (expression instanceof LambdaDefinitionExpression) { counter++; } - return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); + return new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build()); } - private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression filter) + private void generateFilterMethod( + ClassDefinition classDefinition, + CallSiteBinder callSiteBinder, + CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions, + RowExpression filter) { - PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter, "filter"); - Parameter session = arg("session", ConnectorSession.class); Parameter cursor = arg("cursor", RecordCursor.class); MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "filter", type(boolean.class), session, cursor); @@ -262,7 +279,7 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde Scope scope = method.getScope(); Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); - BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( + RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(cursor), @@ -274,7 +291,7 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde .comment("boolean wasNull = false;") .putVariable(wasNullVariable, false) .comment("evaluate filter: " + filter) - .append(filter.accept(visitor, scope)) + .append(compiler.compile(filter, scope)) .comment("if (wasNull) return false;") .getVariable(wasNullVariable) .ifFalseGoto(end) @@ -284,10 +301,14 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde .retBoolean(); } - private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, String methodName, RowExpression projection) + private void generateProjectMethod( + ClassDefinition classDefinition, + CallSiteBinder callSiteBinder, + CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions, + String methodName, + RowExpression projection) { - PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projection, methodName); - Parameter session = arg("session", ConnectorSession.class); Parameter cursor = arg("cursor", RecordCursor.class); Parameter output = arg("output", BlockBuilder.class); @@ -298,7 +319,7 @@ private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBind Scope scope = method.getScope(); Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); - BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( + RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(cursor), @@ -310,14 +331,14 @@ private void generateProjectMethod(ClassDefinition classDefinition, CallSiteBind .putVariable(wasNullVariable, false) .getVariable(output) .comment("evaluate projection: " + projection.toString()) - .append(projection.accept(visitor, scope)) + .append(compiler.compile(projection, scope)) .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) .ret(); } - private static RowExpressionVisitor fieldReferenceCompiler(Variable cursorVariable) + private static RowExpressionVisitor fieldReferenceCompiler(Variable cursorVariable) { - return new RowExpressionVisitor() + return new RowExpressionVisitor() { @Override public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java index 99f0987e12e49..89fdc8c224614 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java @@ -114,7 +114,7 @@ public Supplier compilePageProcessor(Optional filt }; } - private Class compile(Optional filter, List projections, BodyCompiler bodyCompiler, Class superType) + private Class compile(Optional filter, List projections, BodyCompiler bodyCompiler, Class superType) { // create filter and project page iterator class try { @@ -128,7 +128,7 @@ private Class compile(Optional filter, List Class compileProcessor( RowExpression filter, List projections, - BodyCompiler bodyCompiler, + BodyCompiler bodyCompiler, Class superType) { ClassDefinition classDefinition = new ClassDefinition( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java index 9f8fefc92c4db..01e83080f5708 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/FunctionCallCodeGenerator.java @@ -34,8 +34,9 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon ScalarFunctionImplementation function = registry.getScalarFunctionImplementation(signature); List argumentsBytecode = new ArrayList<>(); - for (RowExpression argument : arguments) { - argumentsBytecode.add(context.generate(argument)); + for (int i = 0; i < arguments.size(); i++) { + RowExpression argument = arguments.get(i); + argumentsBytecode.add(context.generate(argument, function.getLambdaInterface().get(i))); } return context.generateCall(signature.getName(), function, argumentsBytecode); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InputReferenceCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InputReferenceCompiler.java index 8ee458b3caa49..6abab878b1ed3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InputReferenceCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InputReferenceCompiler.java @@ -34,7 +34,7 @@ import static java.util.Objects.requireNonNull; class InputReferenceCompiler - implements RowExpressionVisitor + implements RowExpressionVisitor { private final BiFunction blockResolver; private final BiFunction positionResolver; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsolatedClass.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsolatedClass.java index 31e8ae8d479c0..c70be35d3b81b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/IsolatedClass.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/IsolatedClass.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.gen; import com.facebook.presto.bytecode.DynamicClassLoader; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.io.ByteStreams; @@ -24,6 +23,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; public final class IsolatedClass { @@ -58,13 +58,12 @@ public static Class isolateClass( private static byte[] getBytecode(Class clazz) { - InputStream stream = clazz.getClassLoader().getResourceAsStream(clazz.getName().replace('.', '/') + ".class"); - checkArgument(stream != null, "Could not obtain byte code for class %s", clazz.getName()); - try { + try (InputStream stream = clazz.getClassLoader().getResourceAsStream(clazz.getName().replace('.', '/') + ".class")) { + checkArgument(stream != null, "Could not obtain byte code for class %s", clazz.getName()); return ByteStreams.toByteArray(stream); } catch (IOException e) { - throw Throwables.propagate(e); + throw new RuntimeException(format("Could not obtain byte code for class %s", clazz.getName()), e); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java index c013c0812bf7b..699111aa2bb61 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java @@ -48,6 +48,7 @@ import com.google.common.util.concurrent.ExecutionError; import com.google.common.util.concurrent.UncheckedExecutionException; import it.unimi.dsi.fastutil.longs.LongArrayList; +import org.openjdk.jol.info.ClassLayout; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; @@ -62,6 +63,7 @@ import static com.facebook.presto.bytecode.Access.FINAL; import static com.facebook.presto.bytecode.Access.PRIVATE; import static com.facebook.presto.bytecode.Access.PUBLIC; +import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.CompilerUtils.defineClass; import static com.facebook.presto.bytecode.CompilerUtils.makeClassName; @@ -72,6 +74,7 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantLong; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.notEqual; import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; @@ -183,6 +186,21 @@ private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List internalCompileHashStrategy(List types, List outputChannels, List joinChannels, Optional sortChannel) { CallSiteBinder callSiteBinder = new CallSiteBinder(); @@ -193,6 +211,7 @@ private Class internalCompileHashStrategy(List channelFields = new ArrayList<>(); for (int i = 0; i < types.size(); i++) { @@ -208,7 +227,7 @@ private Class internalCompileHashStrategy(List internalCompileHashStrategy(List internalCompileHashStrategy(List joinChannels, FieldDefinition sizeField, + FieldDefinition instanceSizeField, List channelFields, List joinChannelFields, FieldDefinition hashChannelField) @@ -246,8 +267,8 @@ private static void generateConstructor(ClassDefinition classDefinition, .append(thisVariable) .invokeConstructor(Object.class); - constructor.comment("this.size = 0") - .append(thisVariable.setField(sizeField, constantLong(0L))); + constructor.comment("this.size = INSTANCE_SIZE") + .append(thisVariable.setField(sizeField, getStatic(instanceSizeField))); constructor.comment("Set channel fields"); @@ -275,8 +296,7 @@ private static void generateConstructor(ClassDefinition classDefinition, .append( channel.invoke("get", Object.class, blockIndex) .cast(type(Block.class)) - .invoke("getRetainedSizeInBytes", int.class) - .cast(long.class)) + .invoke("getRetainedSizeInBytes", long.class)) .longAdd() .putField(sizeField); } @@ -689,7 +709,7 @@ private static void generatePositionEqualsPositionMethod( .retInt(); } - private static void generateCompareMethod( + private static void generateCompareSortChannelPositionsMethod( ClassDefinition classDefinition, CallSiteBinder callSiteBinder, List types, @@ -702,7 +722,7 @@ private static void generateCompareMethod( Parameter rightBlockPosition = arg("rightBlockPosition", int.class); MethodDefinition compareMethod = classDefinition.declareMethod( a(PUBLIC), - "compare", + "compareSortChannelPositions", type(int.class), leftBlockIndex, leftBlockPosition, @@ -738,6 +758,43 @@ private static void generateCompareMethod( .append(comparison); } + private static void generateIsSortChannelPositionNull( + ClassDefinition classDefinition, + List channelFields, + Optional sortChannel) + { + Parameter blockIndex = arg("blockIndex", int.class); + Parameter blockPosition = arg("blockPosition", int.class); + MethodDefinition isSortChannelPositionNullMethod = classDefinition.declareMethod( + a(PUBLIC), + "isSortChannelPositionNull", + type(boolean.class), + blockIndex, + blockPosition); + + if (!sortChannel.isPresent()) { + isSortChannelPositionNullMethod.getBody() + .append(newInstance(UnsupportedOperationException.class)) + .throwObject(); + return; + } + + Variable thisVariable = isSortChannelPositionNullMethod.getThis(); + + int index = sortChannel.get().getChannel(); + + BytecodeExpression block = thisVariable + .getField(channelFields.get(index)) + .invoke("get", Object.class, blockIndex) + .cast(Block.class); + + BytecodeNode isNull = block.invoke("isNull", boolean.class, blockPosition).ret(); + + isSortChannelPositionNullMethod + .getBody() + .append(isNull); + } + private static BytecodeNode typeEquals( BytecodeExpression type, BytecodeExpression leftBlock, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java index 27cb1976c38e6..ecff33aba4465 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java @@ -29,6 +29,7 @@ import com.facebook.presto.operator.StandardJoinFilterFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; @@ -146,11 +147,17 @@ private void generateMethods(ClassDefinition classDefinition, CallSiteBinder cal FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class); - generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter, leftBlocksSize, sessionField); - generateConstructor(classDefinition, sessionField, cachedInstanceBinder); + PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, leftBlocksSize, filter); + generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, preGeneratedExpressions, filter, leftBlocksSize, sessionField); + + generateConstructor(classDefinition, sessionField, cachedInstanceBinder, preGeneratedExpressions); } - private static void generateConstructor(ClassDefinition classDefinition, FieldDefinition sessionField, CachedInstanceBinder cachedInstanceBinder) + private static void generateConstructor( + ClassDefinition classDefinition, + FieldDefinition sessionField, + CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions) { Parameter sessionParameter = arg("session", ConnectorSession.class); MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), sessionParameter); @@ -164,13 +171,21 @@ private static void generateConstructor(ClassDefinition classDefinition, FieldDe body.append(thisVariable.setField(sessionField, sessionParameter)); cachedInstanceBinder.generateInitializations(thisVariable, body); + for (CompiledLambda compiledLambda : preGeneratedExpressions.getCompiledLambdaMap().values()) { + compiledLambda.generateInitialization(thisVariable, body); + } body.ret(); } - private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression filter, int leftBlocksSize, FieldDefinition sessionField) + private void generateFilterMethod( + ClassDefinition classDefinition, + CallSiteBinder callSiteBinder, + CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions, + RowExpression filter, + int leftBlocksSize, + FieldDefinition sessionField) { - PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, leftBlocksSize, filter); - // int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks Parameter leftPosition = arg("leftPosition", int.class); Parameter leftBlocks = arg("leftBlocks", Block[].class); @@ -195,14 +210,14 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); scope.declareVariable("session", body, method.getThis().getField(sessionField)); - BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( + RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder, leftPosition, leftBlocks, rightPosition, rightBlocks, leftBlocksSize), metadata.getFunctionRegistry(), preGeneratedExpressions); - BytecodeNode visitorBody = filter.accept(visitor, scope); + BytecodeNode visitorBody = compiler.compile(filter, scope); Variable result = scope.declareVariable(boolean.class, "result"); body.append(visitorBody) @@ -222,7 +237,7 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( { Set lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(filter)); ImmutableMap.Builder tryMethodMap = ImmutableMap.builder(); - ImmutableMap.Builder lambdaFieldMap = ImmutableMap.builder(); + ImmutableMap.Builder compiledLambdaMap = ImmutableMap.builder(); int counter = 0; for (RowExpression expression : lambdaAndTryExpressions) { @@ -236,12 +251,12 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( Parameter rightPosition = arg("rightPosition", int.class); Parameter rightBlocks = arg("rightBlocks", Block[].class); - BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor( + RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder, leftPosition, leftBlocks, rightPosition, rightBlocks, leftBlocksSize), metadata.getFunctionRegistry(), - new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build())); + new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build())); List inputParameters = ImmutableList.builder() .add(session) @@ -252,7 +267,7 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( .build(); MethodDefinition tryMethod = defineTryMethod( - innerExpressionVisitor, + innerExpressionCompiler, containerClassDefinition, "try_" + counter, inputParameters, @@ -264,8 +279,8 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( } else if (expression instanceof LambdaDefinitionExpression) { LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) expression; - PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); - FieldDefinition methodHandleField = LambdaBytecodeGenerator.preGenerateLambdaExpression( + PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build()); + CompiledLambda compiledLambda = LambdaBytecodeGenerator.preGenerateLambdaExpression( lambdaExpression, "lambda_" + counter, containerClassDefinition, @@ -273,7 +288,7 @@ else if (expression instanceof LambdaDefinitionExpression) { callSiteBinder, cachedInstanceBinder, metadata.getFunctionRegistry()); - lambdaFieldMap.put(lambdaExpression, methodHandleField); + compiledLambdaMap.put(lambdaExpression, compiledLambda); } else { throw new VerifyException(format("unexpected expression: %s", expression.toString())); @@ -281,7 +296,7 @@ else if (expression instanceof LambdaDefinitionExpression) { counter++; } - return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); + return new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build()); } private static void generateToString(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, String string) @@ -303,7 +318,7 @@ default Optional getSortChannel() } } - private static RowExpressionVisitor fieldReferenceCompiler( + private static RowExpressionVisitor fieldReferenceCompiler( final CallSiteBinder callSiteBinder, final Variable leftPosition, final Variable leftBlocks, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaAndTryExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaAndTryExpressionExtractor.java index 0bf1ea3c78170..a8437a11232ab 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaAndTryExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaAndTryExpressionExtractor.java @@ -44,7 +44,7 @@ public static List extractLambdaAndTryExpressions(RowExpression e } private static class Visitor - implements RowExpressionVisitor + implements RowExpressionVisitor { private final ImmutableList.Builder lambdaAndTryExpressions = ImmutableList.builder(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java index 40a992d7a46cb..ca172e4abdb40 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaBytecodeGenerator.java @@ -19,8 +19,10 @@ import com.facebook.presto.bytecode.FieldDefinition; import com.facebook.presto.bytecode.MethodDefinition; import com.facebook.presto.bytecode.Parameter; +import com.facebook.presto.bytecode.ParameterizedType; import com.facebook.presto.bytecode.Scope; import com.facebook.presto.bytecode.Variable; +import com.facebook.presto.bytecode.expression.BytecodeExpression; import com.facebook.presto.bytecode.expression.BytecodeExpressions; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.ConnectorSession; @@ -28,14 +30,21 @@ import com.facebook.presto.sql.relational.ConstantExpression; import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; +import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.RowExpressionVisitor; import com.facebook.presto.sql.relational.VariableReferenceExpression; import com.facebook.presto.util.Reflection; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Primitives; +import org.objectweb.asm.Handle; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; import java.lang.invoke.MethodHandle; +import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -47,13 +56,22 @@ import static com.facebook.presto.bytecode.Parameter.arg; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantClass; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeDynamic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray; import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic; +import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary; import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary; +import static com.facebook.presto.sql.gen.LambdaCapture.LAMBDA_CAPTURE_METHOD; +import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static org.objectweb.asm.Type.getMethodType; +import static org.objectweb.asm.Type.getType; public class LambdaBytecodeGenerator { @@ -64,7 +82,7 @@ private LambdaBytecodeGenerator() /** * @return a MethodHandle field that represents the lambda expression */ - public static FieldDefinition preGenerateLambdaExpression( + public static CompiledLambda preGenerateLambdaExpression( LambdaDefinitionExpression lambdaExpression, String fieldName, ClassDefinition classDefinition, @@ -85,7 +103,7 @@ public static FieldDefinition preGenerateLambdaExpression( parameterMapBuilder.put(argumentName, new ParameterAndType(arg, type)); } - BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor( + RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, variableReferenceCompiler(parameterMapBuilder.build()), @@ -93,15 +111,15 @@ public static FieldDefinition preGenerateLambdaExpression( preGeneratedExpressions); return defineLambdaMethodAndField( - innerExpressionVisitor, + innerExpressionCompiler, classDefinition, fieldName, parameters.build(), lambdaExpression); } - private static FieldDefinition defineLambdaMethodAndField( - BytecodeExpressionVisitor innerExpressionVisitor, + private static CompiledLambda defineLambdaMethodAndField( + RowExpressionCompiler innerExpressionCompiler, ClassDefinition classDefinition, String fieldAndMethodName, List inputParameters, @@ -112,18 +130,19 @@ private static FieldDefinition defineLambdaMethodAndField( Scope scope = method.getScope(); Variable wasNull = scope.declareVariable(boolean.class, "wasNull"); - BytecodeNode compiledBody = lambda.getBody().accept(innerExpressionVisitor, scope); + BytecodeNode compiledBody = innerExpressionCompiler.compile(lambda.getBody(), scope); method.getBody() .putVariable(wasNull, false) .append(compiledBody) .append(boxPrimitiveIfNecessary(scope, returnType)) .ret(returnType); - FieldDefinition methodHandleField = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), fieldAndMethodName, type(MethodHandle.class)); + FieldDefinition staticField = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), fieldAndMethodName, type(MethodHandle.class)); + FieldDefinition instanceField = classDefinition.declareField(a(PRIVATE, FINAL), "binded_" + fieldAndMethodName, type(MethodHandle.class)); classDefinition.getClassInitializer().getBody() .append(setStatic( - methodHandleField, + staticField, invokeStatic( Reflection.class, "methodHandle", @@ -136,12 +155,91 @@ private static FieldDefinition defineLambdaMethodAndField( .map(Parameter::getType) .map(BytecodeExpressions::constantClass) .collect(toImmutableList()))))); - return methodHandleField; + + Handle lambdaAsmHandle = new Handle( + Opcodes.H_INVOKEVIRTUAL, + method.getThis().getType().getClassName(), + method.getName(), + method.getMethodDescriptor()); + + return new CompiledLambda( + lambdaAsmHandle, + method.getReturnType(), + method.getParameterTypes(), + staticField, + instanceField); + } + + public static BytecodeNode generateLambda( + BytecodeGeneratorContext context, + List captureExpressions, + CompiledLambda compiledLambda, + Class lambdaInterface) + { + if (!lambdaInterface.isAnnotationPresent(FunctionalInterface.class)) { + // lambdaInterface is checked to be annotated with FunctionalInterface when generating ScalarFunctionImplementation + throw new VerifyException("lambda should be generated as class annotated with FunctionalInterface"); + } + + BytecodeBlock block = new BytecodeBlock().setDescription("Partial apply"); + Scope scope = context.getScope(); + + Variable wasNull = scope.getVariable("wasNull"); + + // generate values to be captured + ImmutableList.Builder captureVariableBuilder = ImmutableList.builder(); + for (RowExpression captureExpression : captureExpressions) { + Class valueType = Primitives.wrap(captureExpression.getType().getJavaType()); + Variable valueVariable = scope.createTempVariable(valueType); + block.append(context.generate(captureExpression)); + block.append(boxPrimitiveIfNecessary(scope, valueType)); + block.putVariable(valueVariable); + block.append(wasNull.set(constantFalse())); + captureVariableBuilder.add(valueVariable); + } + + List captureVariables = ImmutableList.builder() + .add(scope.getThis(), scope.getVariable("session")) + .addAll(captureVariableBuilder.build()) + .build(); + + Type instantiatedMethodAsmType = getMethodType( + compiledLambda.getReturnType().getAsmType(), + compiledLambda.getParameterTypes().stream() + .skip(captureExpressions.size() + 1) // skip capture variables and ConnectorSession + .map(ParameterizedType::getAsmType) + .collect(toImmutableList()).toArray(new Type[0])); + + block.append( + invokeDynamic( + LAMBDA_CAPTURE_METHOD, + ImmutableList.of( + getType(getSingleApplyMethod(lambdaInterface)), + compiledLambda.getLambdaAsmHandle(), + instantiatedMethodAsmType + ), + "apply", + type(lambdaInterface), + captureVariables) + ); + return block; + } + + private static Method getSingleApplyMethod(Class lambdaFunctionInterface) + { + checkCondition(lambdaFunctionInterface.isAnnotationPresent(FunctionalInterface.class), COMPILER_ERROR, "Lambda function interface is required to be annotated with FunctionalInterface"); + + List applyMethods = Arrays.stream(lambdaFunctionInterface.getMethods()) + .filter(method -> method.getName().equals("apply")) + .collect(toImmutableList()); + + checkCondition(applyMethods.size() == 1, COMPILER_ERROR, "Expect to have exactly 1 method with name 'apply' in interface " + lambdaFunctionInterface.getName()); + return applyMethods.get(0); } - private static RowExpressionVisitor variableReferenceCompiler(Map parameterMap) + private static RowExpressionVisitor variableReferenceCompiler(Map parameterMap) { - return new RowExpressionVisitor() + return new RowExpressionVisitor() { @Override public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) @@ -179,4 +277,58 @@ public BytecodeNode visitVariableReference(VariableReferenceExpression reference } }; } + + static class CompiledLambda + { + private final FieldDefinition staticField; + // the instance field will be binded to "this" in constructor + private final FieldDefinition instanceField; + + // lambda method information + private final Handle lambdaAsmHandle; + private final ParameterizedType returnType; + private final List parameterTypes; + + public CompiledLambda( + Handle lambdaAsmHandle, + ParameterizedType returnType, + List parameterTypes, + FieldDefinition staticField, + FieldDefinition instanceField) + { + this.staticField = requireNonNull(staticField, "staticField is null"); + this.instanceField = requireNonNull(instanceField, "instanceField is null"); + this.lambdaAsmHandle = requireNonNull(lambdaAsmHandle, "lambdaMethodAsmHandle is null"); + this.returnType = requireNonNull(returnType, "returnType is null"); + this.parameterTypes = ImmutableList.copyOf(requireNonNull(parameterTypes, "returnType is null")); + } + + public Handle getLambdaAsmHandle() + { + return lambdaAsmHandle; + } + + public ParameterizedType getReturnType() + { + return returnType; + } + + public List getParameterTypes() + { + return parameterTypes; + } + + public FieldDefinition getInstanceField() + { + return instanceField; + } + + public void generateInitialization(Variable thisVariable, BytecodeBlock block) + { + block.append( + thisVariable.setField( + instanceField, + getStatic(staticField).invoke("bindTo", MethodHandle.class, thisVariable.cast(Object.class)))); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaCapture.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaCapture.java new file mode 100644 index 0000000000000..c91b1f6948223 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/LambdaCapture.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.gen; + +import com.google.common.base.Throwables; + +import java.lang.invoke.CallSite; +import java.lang.invoke.LambdaConversionException; +import java.lang.invoke.LambdaMetafactory; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.reflect.Method; + +import static com.google.common.base.Throwables.throwIfUnchecked; + +public final class LambdaCapture +{ + public static final Method LAMBDA_CAPTURE_METHOD; + + static { + try { + LAMBDA_CAPTURE_METHOD = LambdaCapture.class.getMethod("lambdaCapture", MethodHandles.Lookup.class, String.class, MethodType.class, MethodType.class, MethodHandle.class, MethodType.class); + } + catch (NoSuchMethodException e) { + throw Throwables.propagate(e); + } + } + + private LambdaCapture() + { + } + + public static CallSite lambdaCapture( + MethodHandles.Lookup callerLookup, + String name, + MethodType type, + MethodType samMethodType, + MethodHandle implMethod, + MethodType instantiatedMethodType) + { + try { + // delegate to metafactory, we may choose to generate code ourselves in the future. + return LambdaMetafactory.metafactory( + callerLookup, + name, + type, + samMethodType, + implMethod, + instantiatedMethodType); + } + catch (LambdaConversionException e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java index dcf8503409a5f..bdb876e77d52d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java @@ -39,9 +39,11 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.ConstantExpression; import com.facebook.presto.sql.relational.DeterminismEvaluator; +import com.facebook.presto.sql.relational.Expressions; import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; @@ -57,6 +59,7 @@ import java.util.List; import java.util.Set; +import java.util.TreeSet; import java.util.function.Consumer; import java.util.function.Supplier; @@ -127,7 +130,7 @@ public Supplier compileProjection(RowExpression projection) projectionClass = defineClass(classDefinition, PageProjection.class, callSiteBinder.getBindings(), getClass().getClassLoader()); } catch (Exception e) { - throw new PrestoException(COMPILER_ERROR, e.getCause()); + throw new PrestoException(COMPILER_ERROR, e); } return () -> { @@ -151,8 +154,11 @@ private ClassDefinition defineProjectionClass(RowExpression projection, InputCha FieldDefinition blockBuilderField = classDefinition.declareField(a(PRIVATE), "blockBuilder", BlockBuilder.class); CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); + generatePageProjectMethod(classDefinition, blockBuilderField); - generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, projection, blockBuilderField); + + PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projection); + generateProjectMethod(classDefinition, callSiteBinder, cachedInstanceBinder, preGeneratedExpressions, projection, blockBuilderField); // getType BytecodeExpression type = invoke(callSiteBinder.bind(projection.getType(), Type.class), "type"); @@ -181,7 +187,7 @@ private ClassDefinition defineProjectionClass(RowExpression projection, InputCha .append(invoke(callSiteBinder.bind(toStringResult, String.class), "toString").ret()); // constructor - generateConstructor(classDefinition, cachedInstanceBinder, method -> { + generateConstructor(classDefinition, cachedInstanceBinder, preGeneratedExpressions, method -> { Variable thisVariable = method.getThis(); BytecodeBlock body = method.getBody(); body.append(thisVariable.setField( @@ -217,6 +223,11 @@ private static MethodDefinition generatePageProjectMethod(ClassDefinition classD Variable positions = scope.declareVariable(int[].class, "positions"); Variable index = scope.declareVariable(int.class, "index"); + // reset block builder before using since a previous run may have thrown leaving data in the block builder + body.append(thisVariable.setField( + blockBuilder, + thisVariable.getField(blockBuilder).invoke("newBlockBuilderLike", BlockBuilder.class, newInstance(BlockBuilderStatus.class)))); + IfStatement ifStatement = new IfStatement() .condition(selectedPositions.invoke("isList", boolean.class)); body.append(ifStatement); @@ -237,10 +248,7 @@ private static MethodDefinition generatePageProjectMethod(ClassDefinition classD Variable block = scope.declareVariable(Block.class, "block"); body.append(block.set(thisVariable.getField(blockBuilder).invoke("build", Block.class))) - .append(thisVariable.setField( - blockBuilder, - thisVariable.getField(blockBuilder).invoke("newBlockBuilderLike", BlockBuilder.class, newInstance(BlockBuilderStatus.class)))) - .append(block.ret()); + .append(block.ret()); return method; } @@ -249,11 +257,10 @@ private MethodDefinition generateProjectMethod( ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions, RowExpression projection, FieldDefinition blockBuilder) { - PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, projection); - Parameter session = arg("session", ConnectorSession.class); Parameter page = arg("page", Page.class); Parameter position = arg("position", int.class); @@ -274,8 +281,10 @@ private MethodDefinition generateProjectMethod( BytecodeBlock body = method.getBody(); Variable thisVariable = method.getThis(); + declareBlockVariables(projection, page, scope, body); + Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); - BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( + RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), @@ -283,7 +292,7 @@ private MethodDefinition generateProjectMethod( preGeneratedExpressions); body.append(thisVariable.getField(blockBuilder)) - .append(projection.accept(visitor, scope)) + .append(compiler.compile(projection, scope)) .append(generateWrite(callSiteBinder, scope, wasNullVariable, projection.getType())) .ret(); return method; @@ -325,7 +334,9 @@ private ClassDefinition defineFilterClass(RowExpression filter, InputChannels in type(PageFilter.class)); CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); - generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter); + + PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter); + generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, preGeneratedExpressions, filter); FieldDefinition selectedPositions = classDefinition.declareField(a(PRIVATE), "selectedPositions", boolean[].class); generatePageFilterMethod(classDefinition, selectedPositions); @@ -354,7 +365,7 @@ private ClassDefinition defineFilterClass(RowExpression filter, InputChannels in .retObject(); // constructor - generateConstructor(classDefinition, cachedInstanceBinder, method -> { + generateConstructor(classDefinition, cachedInstanceBinder, preGeneratedExpressions, method -> { Variable thisVariable = method.getScope().getThis(); method.getBody().append(thisVariable.setField(selectedPositions, newArray(type(boolean[].class), 0))); }); @@ -409,10 +420,9 @@ private MethodDefinition generateFilterMethod( ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions, RowExpression filter) { - PreGeneratedExpressions preGeneratedExpressions = generateMethodsForLambdaAndTry(classDefinition, callSiteBinder, cachedInstanceBinder, filter); - Parameter session = arg("session", ConnectorSession.class); Parameter page = arg("page", Page.class); Parameter position = arg("position", int.class); @@ -432,8 +442,10 @@ private MethodDefinition generateFilterMethod( Scope scope = method.getScope(); BytecodeBlock body = method.getBody(); + declareBlockVariables(filter, page, scope, body); + Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); - BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor( + RowExpressionCompiler compiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), @@ -441,7 +453,7 @@ private MethodDefinition generateFilterMethod( preGeneratedExpressions); Variable result = scope.declareVariable(boolean.class, "result"); - body.append(filter.accept(visitor, scope)) + body.append(compiler.compile(filter, scope)) // store result so we can check for null .putVariable(result) .append(and(not(wasNullVariable), result).ret()); @@ -456,7 +468,7 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( { Set lambdaAndTryExpressions = ImmutableSet.copyOf(extractLambdaAndTryExpressions(expression)); ImmutableMap.Builder tryMethodMap = ImmutableMap.builder(); - ImmutableMap.Builder lambdaFieldMap = ImmutableMap.builder(); + ImmutableMap.Builder compiledLambdaMap = ImmutableMap.builder(); int counter = 0; for (RowExpression lambdaOrTryExpression : lambdaAndTryExpressions) { @@ -465,24 +477,24 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( verify(!Signatures.TRY.equals(tryExpression.getSignature().getName())); Parameter session = arg("session", ConnectorSession.class); - Parameter page = arg("page", Page.class); + List blocks = toBlockParameters(getInputChannels(tryExpression.getArguments())); Parameter position = arg("position", int.class); - BytecodeExpressionVisitor innerExpressionVisitor = new BytecodeExpressionVisitor( + RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler( callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), metadata.getFunctionRegistry(), - new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build())); + new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build())); List inputParameters = ImmutableList.builder() .add(session) - .add(page) + .addAll(blocks) .add(position) .build(); MethodDefinition tryMethod = defineTryMethod( - innerExpressionVisitor, + innerExpressionCompiler, containerClassDefinition, "try_" + counter, inputParameters, @@ -494,8 +506,8 @@ private PreGeneratedExpressions generateMethodsForLambdaAndTry( } else if (lambdaOrTryExpression instanceof LambdaDefinitionExpression) { LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) lambdaOrTryExpression; - PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); - FieldDefinition methodHandleField = LambdaBytecodeGenerator.preGenerateLambdaExpression( + PreGeneratedExpressions preGeneratedExpressions = new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build()); + CompiledLambda compiledLambda = LambdaBytecodeGenerator.preGenerateLambdaExpression( lambdaExpression, "lambda_" + counter, containerClassDefinition, @@ -503,7 +515,7 @@ else if (lambdaOrTryExpression instanceof LambdaDefinitionExpression) { callSiteBinder, cachedInstanceBinder, metadata.getFunctionRegistry()); - lambdaFieldMap.put(lambdaExpression, methodHandleField); + compiledLambdaMap.put(lambdaExpression, compiledLambda); } else { throw new VerifyException(format("unexpected expression: %s", lambdaOrTryExpression.toString())); @@ -511,12 +523,13 @@ else if (lambdaOrTryExpression instanceof LambdaDefinitionExpression) { counter++; } - return new PreGeneratedExpressions(tryMethodMap.build(), lambdaFieldMap.build()); + return new PreGeneratedExpressions(tryMethodMap.build(), compiledLambdaMap.build()); } private static void generateConstructor( ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder, + PreGeneratedExpressions preGeneratedExpressions, Consumer additionalStatements) { MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC)); @@ -531,13 +544,48 @@ private static void generateConstructor( additionalStatements.accept(constructorDefinition); cachedInstanceBinder.generateInitializations(thisVariable, body); + for (CompiledLambda compiledLambda : preGeneratedExpressions.getCompiledLambdaMap().values()) { + compiledLambda.generateInitialization(thisVariable, body); + } body.ret(); } - private static RowExpressionVisitor fieldReferenceCompiler(CallSiteBinder callSiteBinder) + private static void declareBlockVariables(RowExpression expression, Parameter page, Scope scope, BytecodeBlock body) + { + for (int channel : getInputChannels(expression)) { + scope.declareVariable("block_" + channel, body, page.invoke("getBlock", Block.class, constantInt(channel))); + } + } + + private static List getInputChannels(Iterable expressions) + { + TreeSet channels = new TreeSet<>(); + for (RowExpression expression : Expressions.subExpressions(expressions)) { + if (expression instanceof InputReferenceExpression) { + channels.add(((InputReferenceExpression) expression).getField()); + } + } + return ImmutableList.copyOf(channels); + } + + private static List getInputChannels(RowExpression expression) + { + return getInputChannels(ImmutableList.of(expression)); + } + + private static List toBlockParameters(List inputChannels) + { + ImmutableList.Builder parameters = ImmutableList.builder(); + for (int channel : inputChannels) { + parameters.add(arg("block_" + channel, Block.class)); + } + return parameters.build(); + } + + private static RowExpressionVisitor fieldReferenceCompiler(CallSiteBinder callSiteBinder) { return new InputReferenceCompiler( - (scope, field) -> scope.getVariable("page").invoke("getBlock", Block.class, constantInt(field)), + (scope, field) -> scope.getVariable("block_" + field), (scope, field) -> scope.getVariable("position"), callSiteBinder); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/PreGeneratedExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/PreGeneratedExpressions.java index c17196a3d6071..30bbb027bdc26 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/PreGeneratedExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/PreGeneratedExpressions.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.sql.gen; -import com.facebook.presto.bytecode.FieldDefinition; import com.facebook.presto.bytecode.MethodDefinition; +import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.google.common.collect.ImmutableMap; @@ -26,12 +26,12 @@ public class PreGeneratedExpressions { private final Map tryMethodMap; - private final Map lambdaFieldMap; + private final Map compiledLambdaMap; - public PreGeneratedExpressions(Map tryMethodMap, Map lambdaFieldMap) + public PreGeneratedExpressions(Map tryMethodMap, Map compiledLambdaMap) { this.tryMethodMap = ImmutableMap.copyOf(requireNonNull(tryMethodMap, "tryMethodMap is null")); - this.lambdaFieldMap = ImmutableMap.copyOf(requireNonNull(lambdaFieldMap, "lambdaFieldMap is null")); + this.compiledLambdaMap = ImmutableMap.copyOf(requireNonNull(compiledLambdaMap, "compiledLambdaMap is null")); } public Map getTryMethodMap() @@ -39,8 +39,8 @@ public Map getTryMethodMap() return tryMethodMap; } - public Map getLambdaFieldMap() + public Map getCompiledLambdaMap() { - return lambdaFieldMap; + return compiledLambdaMap; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java new file mode 100644 index 0000000000000..5fcc7165d5336 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/RowExpressionCompiler.java @@ -0,0 +1,260 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.gen; + +import com.facebook.presto.bytecode.BytecodeBlock; +import com.facebook.presto.bytecode.BytecodeNode; +import com.facebook.presto.bytecode.Scope; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.sql.relational.CallExpression; +import com.facebook.presto.sql.relational.ConstantExpression; +import com.facebook.presto.sql.relational.InputReferenceExpression; +import com.facebook.presto.sql.relational.LambdaDefinitionExpression; +import com.facebook.presto.sql.relational.RowExpression; +import com.facebook.presto.sql.relational.RowExpressionVisitor; +import com.facebook.presto.sql.relational.VariableReferenceExpression; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; + +import java.util.Optional; + +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue; +import static com.facebook.presto.bytecode.instruction.Constant.loadBoolean; +import static com.facebook.presto.bytecode.instruction.Constant.loadDouble; +import static com.facebook.presto.bytecode.instruction.Constant.loadFloat; +import static com.facebook.presto.bytecode.instruction.Constant.loadInt; +import static com.facebook.presto.bytecode.instruction.Constant.loadLong; +import static com.facebook.presto.bytecode.instruction.Constant.loadString; +import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant; +import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateLambda; +import static com.facebook.presto.sql.relational.Signatures.BIND; +import static com.facebook.presto.sql.relational.Signatures.CAST; +import static com.facebook.presto.sql.relational.Signatures.COALESCE; +import static com.facebook.presto.sql.relational.Signatures.DEREFERENCE; +import static com.facebook.presto.sql.relational.Signatures.IF; +import static com.facebook.presto.sql.relational.Signatures.IN; +import static com.facebook.presto.sql.relational.Signatures.IS_NULL; +import static com.facebook.presto.sql.relational.Signatures.NULL_IF; +import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR; +import static com.facebook.presto.sql.relational.Signatures.SWITCH; +import static com.facebook.presto.sql.relational.Signatures.TRY; +import static com.google.common.base.Preconditions.checkState; + +public class RowExpressionCompiler +{ + private final CallSiteBinder callSiteBinder; + private final CachedInstanceBinder cachedInstanceBinder; + private final RowExpressionVisitor fieldReferenceCompiler; + private final FunctionRegistry registry; + private final PreGeneratedExpressions preGeneratedExpressions; + + RowExpressionCompiler( + CallSiteBinder callSiteBinder, + CachedInstanceBinder cachedInstanceBinder, + RowExpressionVisitor fieldReferenceCompiler, + FunctionRegistry registry, + PreGeneratedExpressions preGeneratedExpressions) + { + this.callSiteBinder = callSiteBinder; + this.cachedInstanceBinder = cachedInstanceBinder; + this.fieldReferenceCompiler = fieldReferenceCompiler; + this.registry = registry; + this.preGeneratedExpressions = preGeneratedExpressions; + } + + public BytecodeNode compile(RowExpression rowExpression, Scope scope) + { + return compile(rowExpression, scope, Optional.empty()); + } + + public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional lambdaInterface) + { + return rowExpression.accept(new Visitor(), new Context(scope, lambdaInterface)); + } + + private class Visitor + implements RowExpressionVisitor + { + @Override + public BytecodeNode visitCall(CallExpression call, Context context) + { + BytecodeGenerator generator; + // special-cased in function registry + if (call.getSignature().getName().equals(CAST)) { + generator = new CastCodeGenerator(); + } + else { + switch (call.getSignature().getName()) { + // lazy evaluation + case IF: + generator = new IfCodeGenerator(); + break; + case NULL_IF: + generator = new NullIfCodeGenerator(); + break; + case SWITCH: + // (SWITCH (WHEN ) (WHEN ) ) + generator = new SwitchCodeGenerator(); + break; + case TRY: + generator = new TryCodeGenerator(preGeneratedExpressions.getTryMethodMap()); + break; + // functions that take null as input + case IS_NULL: + generator = new IsNullCodeGenerator(); + break; + case COALESCE: + generator = new CoalesceCodeGenerator(); + break; + // functions that require varargs and/or complex types (e.g., lists) + case IN: + generator = new InCodeGenerator(registry); + break; + // optimized implementations (shortcircuiting behavior) + case "AND": + generator = new AndCodeGenerator(); + break; + case "OR": + generator = new OrCodeGenerator(); + break; + case DEREFERENCE: + generator = new DereferenceCodeGenerator(); + break; + case ROW_CONSTRUCTOR: + generator = new RowConstructorCodeGenerator(); + break; + case BIND: + generator = new BindCodeGenerator(preGeneratedExpressions.getCompiledLambdaMap(), context.getLambdaInterface().get()); + break; + default: + generator = new FunctionCallCodeGenerator(); + } + } + + BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext( + RowExpressionCompiler.this, + context.getScope(), + callSiteBinder, + cachedInstanceBinder, + registry, + preGeneratedExpressions); + + return generator.generateExpression(call.getSignature(), generatorContext, call.getType(), call.getArguments()); + } + + @Override + public BytecodeNode visitConstant(ConstantExpression constant, Context context) + { + Object value = constant.getValue(); + Class javaType = constant.getType().getJavaType(); + + BytecodeBlock block = new BytecodeBlock(); + if (value == null) { + return block.comment("constant null") + .append(context.getScope().getVariable("wasNull").set(constantTrue())) + .pushJavaDefault(javaType); + } + + // use LDC for primitives (boolean, short, int, long, float, double) + block.comment("constant " + constant.getType().getTypeSignature()); + if (javaType == boolean.class) { + return block.append(loadBoolean((Boolean) value)); + } + if (javaType == byte.class || javaType == short.class || javaType == int.class) { + return block.append(loadInt(((Number) value).intValue())); + } + if (javaType == long.class) { + return block.append(loadLong((Long) value)); + } + if (javaType == float.class) { + return block.append(loadFloat((Float) value)); + } + if (javaType == double.class) { + return block.append(loadDouble((Double) value)); + } + if (javaType == String.class) { + return block.append(loadString((String) value)); + } + if (javaType == void.class) { + return block; + } + + // bind constant object directly into the call-site using invoke dynamic + Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType()); + + return new BytecodeBlock() + .setDescription("constant " + constant.getType()) + .comment(constant.toString()) + .append(loadConstant(binding)); + } + + @Override + public BytecodeNode visitInputReference(InputReferenceExpression node, Context context) + { + return fieldReferenceCompiler.visitInputReference(node, context.getScope()); + } + + @Override + public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context context) + { + checkState(preGeneratedExpressions.getCompiledLambdaMap().containsKey(lambda), "lambda expressions map does not contain this lambda definition"); + if (!context.lambdaInterface.get().isAnnotationPresent(FunctionalInterface.class)) { + // lambdaInterface is checked to be annotated with FunctionalInterface when generating ScalarFunctionImplementation + throw new VerifyException("lambda should be generated as class annotated with FunctionalInterface"); + } + + BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext( + RowExpressionCompiler.this, + context.getScope(), + callSiteBinder, + cachedInstanceBinder, + registry, + preGeneratedExpressions); + + return generateLambda( + generatorContext, + ImmutableList.of(), + preGeneratedExpressions.getCompiledLambdaMap().get(lambda), + context.getLambdaInterface().get()); + } + + @Override + public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Context context) + { + return fieldReferenceCompiler.visitVariableReference(reference, context.getScope()); + } + } + + private static class Context + { + private final Scope scope; + private final Optional lambdaInterface; + + public Context(Scope scope, Optional lambdaInterface) + { + this.scope = scope; + this.lambdaInterface = lambdaInterface; + } + + public Scope getScope() + { + return scope; + } + + public Optional getLambdaInterface() + { + return lambdaInterface; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java index 2fc8baf6c501d..d7e86add779b3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/TryCodeGenerator.java @@ -57,6 +57,7 @@ public class TryCodeGenerator implements BytecodeGenerator { private static final String EXCEPTION_HANDLER_NAME = "tryExpressionExceptionHandler"; + private static final MethodHandle EXCEPTION_HANDLER = methodHandle(TryCodeGenerator.class, EXCEPTION_HANDLER_NAME, PrestoException.class); private final Map tryMethodsMap; @@ -86,7 +87,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon } public static MethodDefinition defineTryMethod( - BytecodeExpressionVisitor innerExpressionVisitor, + RowExpressionCompiler innerExpressionCompiler, ClassDefinition classDefinition, String methodName, List inputParameters, @@ -98,10 +99,10 @@ public static MethodDefinition defineTryMethod( Scope calleeMethodScope = method.getScope(); Variable wasNull = calleeMethodScope.declareVariable(boolean.class, "wasNull"); - BytecodeNode innerExpression = innerRowExpression.accept(innerExpressionVisitor, calleeMethodScope); + BytecodeNode innerExpression = innerExpressionCompiler.compile(innerRowExpression, calleeMethodScope); MethodType exceptionHandlerType = methodType(returnType, PrestoException.class); - MethodHandle exceptionHandler = methodHandle(TryCodeGenerator.class, EXCEPTION_HANDLER_NAME, PrestoException.class).asType(exceptionHandlerType); + MethodHandle exceptionHandler = EXCEPTION_HANDLER.asType(exceptionHandlerType); Binding binding = callSiteBinder.bind(exceptionHandler); method.comment("Try projection: %s", innerRowExpression.toString()); diff --git a/presto-hive-cdh5/src/main/java/com/facebook/presto/hive/HiveCdh5Plugin.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/BinaryFunctionInterface.java similarity index 75% rename from presto-hive-cdh5/src/main/java/com/facebook/presto/hive/HiveCdh5Plugin.java rename to presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/BinaryFunctionInterface.java index e0b59d415071d..dcf5cfa84fb4b 100644 --- a/presto-hive-cdh5/src/main/java/com/facebook/presto/hive/HiveCdh5Plugin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/BinaryFunctionInterface.java @@ -11,13 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.hive; +package com.facebook.presto.sql.gen.lambda; -public class HiveCdh5Plugin - extends HivePlugin +@FunctionalInterface +public interface BinaryFunctionInterface extends LambdaFunctionInterface { - public HiveCdh5Plugin() - { - super("hive-cdh5"); - } + Object apply(Object arg1, Object arg2); } diff --git a/presto-hive-cdh4/src/main/java/com/facebook/presto/hive/HiveCdh4Plugin.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/LambdaFunctionInterface.java similarity index 77% rename from presto-hive-cdh4/src/main/java/com/facebook/presto/hive/HiveCdh4Plugin.java rename to presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/LambdaFunctionInterface.java index 20ba182f873d1..04a42d07b8bdf 100644 --- a/presto-hive-cdh4/src/main/java/com/facebook/presto/hive/HiveCdh4Plugin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/LambdaFunctionInterface.java @@ -11,13 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.hive; +package com.facebook.presto.sql.gen.lambda; -public class HiveCdh4Plugin - extends HivePlugin -{ - public HiveCdh4Plugin() - { - super("hive-cdh4"); - } -} +public interface LambdaFunctionInterface {} diff --git a/presto-hive-hadoop1/src/main/java/com/facebook/presto/hive/HiveHadoop1Plugin.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/UnaryFunctionInterface.java similarity index 77% rename from presto-hive-hadoop1/src/main/java/com/facebook/presto/hive/HiveHadoop1Plugin.java rename to presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/UnaryFunctionInterface.java index fa83097ae9ca4..f8110c264ceea 100644 --- a/presto-hive-hadoop1/src/main/java/com/facebook/presto/hive/HiveHadoop1Plugin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/lambda/UnaryFunctionInterface.java @@ -11,13 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.hive; +package com.facebook.presto.sql.gen.lambda; -public class HiveHadoop1Plugin - extends HivePlugin +@FunctionalInterface +public interface UnaryFunctionInterface extends LambdaFunctionInterface { - public HiveHadoop1Plugin() - { - super("hive-hadoop1"); - } + Object apply(Object arg1); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DesugaringRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DesugaringRewriter.java index 95e92aa4332ed..28866019b8064 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DesugaringRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DesugaringRewriter.java @@ -20,9 +20,12 @@ import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.QualifiedName; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; import static com.facebook.presto.spi.type.TimeType.TIME; import static com.facebook.presto.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; @@ -33,18 +36,18 @@ public class DesugaringRewriter extends ExpressionRewriter { - private final IdentityLinkedHashMap expressionTypes; + private final Map, Type> expressionTypes; - public DesugaringRewriter(IdentityLinkedHashMap expressionTypes) + public DesugaringRewriter(Map, Type> expressionTypes) { - this.expressionTypes = requireNonNull(expressionTypes, "expressionTypes is null"); + this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null")); } @Override public Expression rewriteAtTimeZone(AtTimeZone node, Void context, ExpressionTreeRewriter treeRewriter) { Expression value = treeRewriter.rewrite(node.getValue(), context); - Type type = expressionTypes.get(node.getValue()); + Type type = getType(node.getValue()); if (type.equals(TIME)) { value = new Cast(value, TIME_WITH_TIME_ZONE.getDisplayName()); } @@ -56,4 +59,9 @@ else if (type.equals(TIMESTAMP)) { value, treeRewriter.rewrite(node.getTimeZone(), context))); } + + private Type getType(Expression expression) + { + return expressionTypes.get(NodeRef.of(expression)); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DeterminismEvaluator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DeterminismEvaluator.java index d4949ca961b26..5383f805bed2d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DeterminismEvaluator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DeterminismEvaluator.java @@ -20,6 +20,8 @@ import java.util.concurrent.atomic.AtomicBoolean; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.BIGINT_GROUPING; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.INTEGER_GROUPING; import static java.util.Objects.requireNonNull; /** @@ -47,7 +49,9 @@ protected Void visitFunctionCall(FunctionCall node, AtomicBoolean deterministic) // TODO: total hack to figure out if a function is deterministic. martint should fix this when he refactors the planning code if (node.getName().equals(QualifiedName.of("rand")) || node.getName().equals(QualifiedName.of("random")) || - node.getName().equals(QualifiedName.of("shuffle"))) { + node.getName().equals(QualifiedName.of("shuffle")) || + node.getName().equals(QualifiedName.of(BIGINT_GROUPING)) || + node.getName().equals(QualifiedName.of(INTEGER_GROUPING))) { deterministic.set(false); } return super.visitFunctionCall(node, deterministic); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DistributedExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DistributedExecutionPlanner.java index f63c125b3ac7f..36bb5c43004d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DistributedExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DistributedExecutionPlanner.java @@ -117,7 +117,7 @@ private StageExecutionPlan plan(SubPlan root, Visitor visitor) } private final class Visitor - extends PlanVisitor> + extends PlanVisitor, Void> { private final Session session; private final List splitSources = new ArrayList<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java index c196985bb24e9..31faf21ed1886 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java @@ -43,10 +43,10 @@ import com.facebook.presto.sql.tree.IsNotNullPredicate; import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.SymbolReference; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; @@ -453,12 +453,12 @@ else if (symbolExpression instanceof Cast) { */ private Optional toNormalizedSimpleComparison(ComparisonExpression comparison) { - IdentityLinkedHashMap expressionTypes = analyzeExpression(comparison); + Map, Type> expressionTypes = analyzeExpression(comparison); Object left = ExpressionInterpreter.expressionOptimizer(comparison.getLeft(), metadata, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE); Object right = ExpressionInterpreter.expressionOptimizer(comparison.getRight(), metadata, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE); - Type leftType = expressionTypes.get(comparison.getLeft()); - Type rightType = expressionTypes.get(comparison.getRight()); + Type leftType = expressionTypes.get(NodeRef.of(comparison.getLeft())); + Type rightType = expressionTypes.get(NodeRef.of(comparison.getRight())); // TODO: re-enable this check once we fix the type coercions in the optimizers // checkArgument(leftType.equals(rightType), "left and right type do not match in comparison expression (%s)", comparison); @@ -488,11 +488,13 @@ private Optional toNormalizedSimpleComparison(Compar private boolean isImplicitCoercion(Cast cast) { - IdentityLinkedHashMap expressionTypes = analyzeExpression(cast); - return metadata.getTypeManager().canCoerce(expressionTypes.get(cast.getExpression()), expressionTypes.get(cast)); + Map, Type> expressionTypes = analyzeExpression(cast); + Type actualType = expressionTypes.get(NodeRef.of(cast.getExpression())); + Type expectedType = expressionTypes.get(NodeRef.of(cast)); + return metadata.getTypeManager().canCoerce(actualType, expectedType); } - private IdentityLinkedHashMap analyzeExpression(Expression expression) + private Map, Type> analyzeExpression(Expression expression) { return ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList() /* parameters already replaced */); } @@ -750,8 +752,8 @@ protected ExtractionResult visitNullLiteral(NullLiteral node, Boolean complement private static Type typeOf(Expression expression, Session session, Metadata metadata, Map types) { - IdentityLinkedHashMap expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList() /* parameters already replaced */); - return expressionTypes.get(expression); + Map, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList() /* parameters already replaced */); + return expressionTypes.get(NodeRef.of(expression)); } private static class NormalizedSimpleComparison diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java index 4532ee2116b67..6d6d8345962b8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java @@ -67,7 +67,7 @@ * Note: non-deterministic predicates can not be pulled up (so they will be ignored) */ public class EffectivePredicateExtractor - extends PlanVisitor + extends PlanVisitor { public static Expression extract(PlanNode node, Map symbolTypes) { @@ -81,6 +81,7 @@ public static Expression extract(PlanNode node, Map symbolTypes) entry -> { SymbolReference reference = entry.getKey().toSymbolReference(); Expression expression = entry.getValue(); + // TODO: this is not correct with respect to NULLs ('reference IS NULL' would be correct, rather than 'reference = NULL') // TODO: switch this to 'IS NOT DISTINCT FROM' syntax when EqualityInference properly supports it return new ComparisonExpression(ComparisonExpressionType.EQUAL, reference, expression); }; @@ -223,12 +224,9 @@ public Expression visitJoin(JoinNode node, Void context) Expression leftPredicate = node.getLeft().accept(this, context); Expression rightPredicate = node.getRight().accept(this, context); - List joinConjuncts = new ArrayList<>(); - for (JoinNode.EquiJoinClause clause : node.getCriteria()) { - joinConjuncts.add(new ComparisonExpression(ComparisonExpressionType.EQUAL, - clause.getLeft().toSymbolReference(), - clause.getRight().toSymbolReference())); - } + List joinConjuncts = node.getCriteria().stream() + .map(JoinNode.EquiJoinClause::toExpression) + .collect(toImmutableList()); switch (node.getType()) { case INNER: @@ -265,7 +263,7 @@ private static Iterable pullNullableConjunctsThroughOuterJoin(List pullExpressionThroughSymbols(expression, outputSymbols)) - .map(expression -> DependencyExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) + .map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? TRUE_LITERAL : expression) .map(expressionOrNullSymbols(nullSymbolScopes)) .collect(toImmutableList()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java index 83b3c49674fd7..f71a6346eb8bc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EqualityInference.java @@ -65,7 +65,7 @@ public int compare(Expression expression1, Expression expression2) // 3) Sort the expressions alphabetically - creates a stable consistent ordering (extremely useful for unit testing) // TODO: be more precise in determining the cost of an expression return ComparisonChain.start() - .compare(DependencyExtractor.extractAll(expression1).size(), DependencyExtractor.extractAll(expression2).size()) + .compare(SymbolsExtractor.extractAll(expression1).size(), SymbolsExtractor.extractAll(expression2).size()) .compare(SubExpressionExtractor.extract(expression1).size(), SubExpressionExtractor.extract(expression2).size()) .compare(expression1.toString(), expression2.toString()) .result(); @@ -244,7 +244,7 @@ Expression getScopedCanonical(Expression expression, Predicate symbolSco private static Predicate symbolToExpressionPredicate(final Predicate symbolScope) { - return expression -> Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope); + return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope); } /** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java index 9064448e7bd83..863709476b089 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -26,19 +28,30 @@ import java.util.List; +import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; +import static java.util.Objects.requireNonNull; + public class ExpressionExtractor { public static List extractExpressions(PlanNode plan) { + return extractExpressions(plan, noLookup()); + } + + public static List extractExpressions(PlanNode plan, Lookup lookup) + { + requireNonNull(plan, "plan is null"); + requireNonNull(lookup, "lookup is null"); + ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); - plan.accept(new Visitor(true), expressionsBuilder); + plan.accept(new Visitor(true, lookup), expressionsBuilder); return expressionsBuilder.build(); } public static List extractExpressionsNonRecursive(PlanNode plan) { ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); - plan.accept(new Visitor(false), expressionsBuilder); + plan.accept(new Visitor(false, noLookup()), expressionsBuilder); return expressionsBuilder.build(); } @@ -50,10 +63,12 @@ private static class Visitor extends SimplePlanVisitor> { private final boolean recursive; + private final Lookup lookup; - public Visitor(boolean recursive) + Visitor(boolean recursive, Lookup lookup) { this.recursive = recursive; + this.lookup = requireNonNull(lookup, "lookup is null"); } @Override @@ -65,10 +80,16 @@ protected Void visitPlan(PlanNode node, ImmutableList.Builder contex return null; } + @Override + public Void visitGroupReference(GroupReference node, ImmutableList.Builder context) + { + return lookup.resolve(node).accept(this, context); + } + @Override public Void visitAggregation(AggregationNode node, ImmutableList.Builder context) { - node.getAssignments().values() + node.getAggregations().values() .forEach(aggregation -> context.add(aggregation.getCall())); return super.visitAggregation(node, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index 10a4601a364c2..21ef526fa4cad 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -28,6 +28,9 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; +import com.facebook.presto.spi.type.RowType.RowField; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -67,6 +70,7 @@ import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullLiteral; @@ -82,21 +86,16 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.FunctionType; import com.facebook.presto.type.LikeFunctions; -import com.facebook.presto.type.RowType; -import com.facebook.presto.type.RowType.RowField; import com.facebook.presto.util.Failures; import com.facebook.presto.util.FastutilSetHelper; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Defaults; -import com.google.common.base.Functions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.primitives.Primitives; import io.airlift.joni.Regex; import io.airlift.json.JsonCodec; @@ -105,6 +104,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -131,6 +131,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.any; +import static java.lang.invoke.MethodHandleProxies.asInterfaceInstance; import static java.util.Objects.requireNonNull; public class ExpressionInterpreter @@ -139,15 +140,15 @@ public class ExpressionInterpreter private final Metadata metadata; private final ConnectorSession session; private final boolean optimize; - private final IdentityLinkedHashMap expressionTypes; + private final Map, Type> expressionTypes; private final Visitor visitor; // identity-based cache for LIKE expressions with constant pattern and escape char - private final IdentityLinkedHashMap likePatternCache = new IdentityLinkedHashMap<>(); - private final IdentityLinkedHashMap> inListCache = new IdentityLinkedHashMap<>(); + private final IdentityHashMap likePatternCache = new IdentityHashMap<>(); + private final IdentityHashMap> inListCache = new IdentityHashMap<>(); - public static ExpressionInterpreter expressionInterpreter(Expression expression, Metadata metadata, Session session, IdentityLinkedHashMap expressionTypes) + public static ExpressionInterpreter expressionInterpreter(Expression expression, Metadata metadata, Session session, Map, Type> expressionTypes) { requireNonNull(expression, "expression is null"); requireNonNull(metadata, "metadata is null"); @@ -156,7 +157,7 @@ public static ExpressionInterpreter expressionInterpreter(Expression expression, return new ExpressionInterpreter(expression, metadata, session, expressionTypes, false); } - public static ExpressionInterpreter expressionOptimizer(Expression expression, Metadata metadata, Session session, IdentityLinkedHashMap expressionTypes) + public static ExpressionInterpreter expressionOptimizer(Expression expression, Metadata metadata, Session session, Map, Type> expressionTypes) { requireNonNull(expression, "expression is null"); requireNonNull(metadata, "metadata is null"); @@ -170,24 +171,25 @@ public static Object evaluateConstantExpression(Expression expression, Type expe ExpressionAnalyzer analyzer = createConstantAnalyzer(metadata, session, parameters); analyzer.analyze(expression, Scope.create()); - Type actualType = analyzer.getExpressionTypes().get(expression); + Type actualType = analyzer.getExpressionTypes().get(NodeRef.of(expression)); if (!metadata.getTypeManager().canCoerce(actualType, expectedType)) { throw new SemanticException(SemanticErrorCode.TYPE_MISMATCH, expression, String.format("Cannot cast type %s to %s", expectedType.getTypeSignature(), actualType.getTypeSignature())); } - IdentityLinkedHashMap coercions = new IdentityLinkedHashMap<>(); - coercions.putAll(analyzer.getExpressionCoercions()); - coercions.put(expression, expectedType); + Map, Type> coercions = ImmutableMap., Type>builder() + .putAll(analyzer.getExpressionCoercions()) + .put(NodeRef.of(expression), expectedType) + .build(); return evaluateConstantExpression(expression, coercions, metadata, session, ImmutableSet.of(), parameters); } public static Object evaluateConstantExpression( Expression expression, - IdentityLinkedHashMap coercions, + Map, Type> coercions, Metadata metadata, Session session, - Set columnReferences, + Set> columnReferences, List parameters) { requireNonNull(columnReferences, "columnReferences is null"); @@ -203,7 +205,7 @@ public Expression rewriteExpression(Expression node, Void context, ExpressionTre Expression rewrittenExpression = treeRewriter.defaultRewrite(node, context); // cast expression if coercion is registered - Type coerceToType = coercions.get(node); + Type coerceToType = coercions.get(NodeRef.of(node)); if (coerceToType != null) { rewrittenExpression = new Cast(rewrittenExpression, coerceToType.getTypeSignature().toString()); @@ -235,18 +237,18 @@ public Expression rewriteExpression(Expression node, Void context, ExpressionTre return result; } - public static void verifyExpressionIsConstant(Set columnReferences, Expression expression) + public static void verifyExpressionIsConstant(Set> columnReferences, Expression expression) { new ConstantExpressionVerifierVisitor(columnReferences, expression).process(expression, null); } - private ExpressionInterpreter(Expression expression, Metadata metadata, Session session, IdentityLinkedHashMap expressionTypes, boolean optimize) + private ExpressionInterpreter(Expression expression, Metadata metadata, Session session, Map, Type> expressionTypes, boolean optimize) { this.expression = expression; this.metadata = metadata; this.session = session.toConnectorSession(); - this.expressionTypes = expressionTypes; - verify((expressionTypes.containsKey(expression))); + this.expressionTypes = ImmutableMap.copyOf(requireNonNull(expressionTypes, "expressionTypes is null")); + verify((expressionTypes.containsKey(NodeRef.of(expression)))); this.optimize = optimize; this.visitor = new Visitor(); @@ -254,7 +256,7 @@ private ExpressionInterpreter(Expression expression, Metadata metadata, Session public Type getType() { - return expressionTypes.get(expression); + return expressionTypes.get(NodeRef.of(expression)); } public Object evaluate(RecordCursor inputs) @@ -284,10 +286,10 @@ public Object optimize(SymbolResolver inputs) private static class ConstantExpressionVerifierVisitor extends DefaultTraversalVisitor { - private final Set columnReferences; + private final Set> columnReferences; private final Expression expression; - public ConstantExpressionVerifierVisitor(Set columnReferences, Expression expression) + public ConstantExpressionVerifierVisitor(Set> columnReferences, Expression expression) { this.columnReferences = columnReferences; this.expression = expression; @@ -296,7 +298,7 @@ public ConstantExpressionVerifierVisitor(Set columnReferences, Expre @Override protected Void visitDereferenceExpression(DereferenceExpression node, Void context) { - if (columnReferences.contains(node)) { + if (columnReferences.contains(NodeRef.of(node))) { throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references"); } @@ -324,7 +326,7 @@ private class Visitor @Override public Object visitFieldReference(FieldReference node, Object context) { - Type type = expressionTypes.get(node); + Type type = type(node); int channel = node.getFieldIndex(); if (context instanceof PagePositionContext) { @@ -388,7 +390,7 @@ else if (javaType == Block.class) { @Override protected Object visitDereferenceExpression(DereferenceExpression node, Object context) { - Type type = expressionTypes.get(node.getBase()); + Type type = type(node.getBase()); // if there is no type for the base of Dereference, it must be QualifiedName if (type == null) { return node; @@ -406,7 +408,7 @@ protected Object visitDereferenceExpression(DereferenceExpression node, Object c RowType rowType = (RowType) type; Block row = (Block) base; - Type returnType = expressionTypes.get(node); + Type returnType = type(node); List fields = rowType.getFields(); int index = -1; for (int i = 0; i < fields.size(); i++) { @@ -469,7 +471,7 @@ protected Object visitIsNullPredicate(IsNullPredicate node, Object context) Object value = process(node.getValue(), context); if (value instanceof Expression) { - return new IsNullPredicate(toExpression(value, expressionTypes.get(node.getValue()))); + return new IsNullPredicate(toExpression(value, type(node.getValue()))); } return value == null; @@ -481,7 +483,7 @@ protected Object visitIsNotNullPredicate(IsNotNullPredicate node, Object context Object value = process(node.getValue(), context); if (value instanceof Expression) { - return new IsNotNullPredicate(toExpression(value, expressionTypes.get(node.getValue()))); + return new IsNotNullPredicate(toExpression(value, type(node.getValue()))); } return value != null; @@ -606,7 +608,7 @@ private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2 private Type type(Expression expression) { - return expressionTypes.get(expression); + return expressionTypes.get(NodeRef.of(expression)); } @Override @@ -658,7 +660,7 @@ protected Object visitInPredicate(InPredicate node, Object context) if (valueList.getValues().stream().allMatch(Literal.class::isInstance) && valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) { Set objectSet = valueList.getValues().stream().map(expression -> process(expression, context)).collect(Collectors.toSet()); - set = FastutilSetHelper.toFastutilHashSet(objectSet, expressionTypes.get(node.getValue()), metadata.getFunctionRegistry()); + set = FastutilSetHelper.toFastutilHashSet(objectSet, type(node.getValue()), metadata.getFunctionRegistry()); } inListCache.put(valueList, set); } @@ -681,7 +683,7 @@ protected Object visitInPredicate(InPredicate node, Object context) if (value instanceof Expression || inValue instanceof Expression) { hasUnresolvedValue = true; values.add(inValue); - types.add(expressionTypes.get(expression)); + types.add(type(expression)); continue; } @@ -698,7 +700,7 @@ else if (!found && (Boolean) invokeOperator(OperatorType.EQUAL, types(node.getVa } if (hasUnresolvedValue) { - Type type = expressionTypes.get(node.getValue()); + Type type = type(node.getValue()); List expressionValues = toExpressions(values, types); List simplifiedExpressionValues = Stream.concat( expressionValues.stream() @@ -741,7 +743,7 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con return null; } if (value instanceof Expression) { - return new ArithmeticUnaryExpression(node.getSign(), toExpression(value, expressionTypes.get(node.getValue()))); + return new ArithmeticUnaryExpression(node.getSign(), toExpression(value, type(node.getValue()))); } switch (node.getSign()) { @@ -780,7 +782,7 @@ protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, Object c } if (hasUnresolvedValue(left, right)) { - return new ArithmeticBinaryExpression(node.getType(), toExpression(left, expressionTypes.get(node.getLeft())), toExpression(right, expressionTypes.get(node.getRight()))); + return new ArithmeticBinaryExpression(node.getType(), toExpression(left, type(node.getLeft())), toExpression(right, type(node.getRight()))); } return invokeOperator(OperatorType.valueOf(node.getType().name()), types(node.getLeft(), node.getRight()), ImmutableList.of(left, right)); @@ -810,7 +812,7 @@ else if (right == null) { } if (hasUnresolvedValue(left, right)) { - return new ComparisonExpression(type, toExpression(left, expressionTypes.get(node.getLeft())), toExpression(right, expressionTypes.get(node.getRight()))); + return new ComparisonExpression(type, toExpression(left, type(node.getLeft())), toExpression(right, type(node.getRight()))); } return invokeOperator(OperatorType.valueOf(type.name()), types(node.getLeft(), node.getRight()), ImmutableList.of(left, right)); @@ -834,9 +836,9 @@ protected Object visitBetweenPredicate(BetweenPredicate node, Object context) if (hasUnresolvedValue(value, min, max)) { return new BetweenPredicate( - toExpression(value, expressionTypes.get(node.getValue())), - toExpression(min, expressionTypes.get(node.getMin())), - toExpression(max, expressionTypes.get(node.getMax()))); + toExpression(value, type(node.getValue())), + toExpression(min, type(node.getMin())), + toExpression(max, type(node.getMax()))); } return invokeOperator(OperatorType.BETWEEN, types(node.getValue(), node.getMin(), node.getMax()), ImmutableList.of(value, min, max)); @@ -854,8 +856,8 @@ protected Object visitNullIfExpression(NullIfExpression node, Object context) return first; } - Type firstType = expressionTypes.get(node.getFirst()); - Type secondType = expressionTypes.get(node.getSecond()); + Type firstType = type(node.getFirst()); + Type secondType = type(node.getSecond()); if (hasUnresolvedValue(first, second)) { return new NullIfExpression(toExpression(first, firstType), toExpression(second, secondType)); @@ -893,7 +895,7 @@ protected Object visitNotExpression(NotExpression node, Object context) } if (value instanceof Expression) { - return new NotExpression(toExpression(value, expressionTypes.get(node.getValue()))); + return new NotExpression(toExpression(value, type(node.getValue()))); } return !(Boolean) value; @@ -935,8 +937,8 @@ protected Object visitLogicalBinaryExpression(LogicalBinaryExpression node, Obje } return new LogicalBinaryExpression(node.getType(), - toExpression(left, expressionTypes.get(node.getLeft())), - toExpression(right, expressionTypes.get(node.getRight()))); + toExpression(left, type(node.getLeft())), + toExpression(right, type(node.getRight()))); } @Override @@ -952,7 +954,7 @@ protected Object visitFunctionCall(FunctionCall node, Object context) List argumentValues = new ArrayList<>(); for (Expression expression : node.getArguments()) { Object value = process(expression, context); - Type type = expressionTypes.get(expression); + Type type = type(expression); argumentValues.add(value); argumentTypes.add(type); } @@ -985,7 +987,7 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context) List argumentNames = node.getArguments().stream() .map(LambdaArgumentDeclaration::getName) .collect(toImmutableList()); - FunctionType functionType = (FunctionType) expressionTypes.get(node); + FunctionType functionType = (FunctionType) expressionTypes.get(NodeRef.of(node)); checkArgument(argumentNames.size() == functionType.getArgumentTypes().size()); return generateVarArgsToMapAdapter( @@ -1001,16 +1003,23 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context) @Override protected Object visitBindExpression(BindExpression node, Object context) { - Object value = process(node.getValue(), context); + List values = node.getValues().stream() + .map(value -> process(value, context)) + .collect(toImmutableList()); Object function = process(node.getFunction(), context); - if (hasUnresolvedValue(value, function)) { + if (hasUnresolvedValue(values) || hasUnresolvedValue(function)) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < values.size(); i++) { + builder.add(toExpression(values.get(i), type(node.getValues().get(i)))); + } + return new BindExpression( - toExpression(value, expressionTypes.get(node.getValue())), - toExpression(function, expressionTypes.get(node.getFunction()))); + builder.build(), + toExpression(function, type(node.getFunction()))); } - return MethodHandles.insertArguments((MethodHandle) function, 0, value); + return MethodHandles.insertArguments((MethodHandle) function, 0, values.toArray()); } @Override @@ -1061,7 +1070,7 @@ protected Object visitLikePredicate(LikePredicate node, Object context) // if pattern is a constant without % or _ replace with a comparison if (pattern instanceof Slice && (escape == null || escape instanceof Slice) && !isLikePattern((Slice) pattern, (Slice) escape)) { Slice unescapedPattern = unescapeLiteralLikePattern((Slice) pattern, (Slice) escape); - Type valueType = expressionTypes.get(node.getValue()); + Type valueType = type(node.getValue()); Type patternType = createVarcharType(unescapedPattern.length()); TypeManager typeManager = metadata.getTypeManager(); Optional commonSuperType = typeManager.getCommonSuperType(valueType, patternType); @@ -1080,12 +1089,12 @@ protected Object visitLikePredicate(LikePredicate node, Object context) Expression optimizedEscape = null; if (node.getEscape() != null) { - optimizedEscape = toExpression(escape, expressionTypes.get(node.getEscape())); + optimizedEscape = toExpression(escape, type(node.getEscape())); } return new LikePredicate( - toExpression(value, expressionTypes.get(node.getValue())), - toExpression(pattern, expressionTypes.get(node.getPattern())), + toExpression(value, type(node.getValue())), + toExpression(pattern, type(node.getPattern())), optimizedEscape); } @@ -1142,8 +1151,8 @@ public Object visitCast(Cast node, Object context) // hack!!! don't optimize CASTs for types that cannot be represented in the SQL AST // TODO: this will not be an issue when we migrate to RowExpression tree for this, which allows arbitrary literals. - if (optimize && !FunctionRegistry.isSupportedLiteralType(expressionTypes.get(node))) { - return new Cast(toExpression(value, expressionTypes.get(node.getExpression())), node.getType(), node.isSafe(), node.isTypeOnly()); + if (optimize && !FunctionRegistry.isSupportedLiteralType(type(node))) { + return new Cast(toExpression(value, type(node.getExpression())), node.getType(), node.isSafe(), node.isTypeOnly()); } if (value == null) { @@ -1155,7 +1164,7 @@ public Object visitCast(Cast node, Object context) throw new IllegalArgumentException("Unsupported type: " + node.getType()); } - Signature operator = metadata.getFunctionRegistry().getCoercion(expressionTypes.get(node.getExpression()), type); + Signature operator = metadata.getFunctionRegistry().getCoercion(type(node.getExpression()), type); try { return invoke(session, metadata.getFunctionRegistry().getScalarFunctionImplementation(operator), ImmutableList.of(value)); @@ -1171,7 +1180,7 @@ public Object visitCast(Cast node, Object context) @Override protected Object visitArrayConstructor(ArrayConstructor node, Object context) { - Type elementType = ((ArrayType) expressionTypes.get(node)).getElementType(); + Type elementType = ((ArrayType) type(node)).getElementType(); BlockBuilder arrayBlockBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), node.getValues().size()); for (Expression expression : node.getValues()) { @@ -1188,7 +1197,7 @@ protected Object visitArrayConstructor(ArrayConstructor node, Object context) @Override protected Object visitRow(Row node, Object context) { - RowType rowType = (RowType) expressionTypes.get(node); + RowType rowType = (RowType) type(node); List parameterTypes = rowType.getTypeParameters(); List arguments = node.getItems(); @@ -1220,12 +1229,12 @@ protected Object visitSubscriptExpression(SubscriptExpression node, Object conte if (index == null) { return null; } - if ((index instanceof Long) && isArray(expressionTypes.get(node.getBase()))) { + if ((index instanceof Long) && isArray(type(node.getBase()))) { ArraySubscriptOperator.checkArrayIndex((Long) index); } if (hasUnresolvedValue(base, index)) { - return new SubscriptExpression(toExpression(base, expressionTypes.get(node.getBase())), toExpression(index, expressionTypes.get(node.getIndex()))); + return new SubscriptExpression(toExpression(base, type(node.getBase())), toExpression(index, type(node.getIndex()))); } return invokeOperator(OperatorType.SUBSCRIPT, types(node.getBase(), node.getIndex()), ImmutableList.of(base, index)); @@ -1254,7 +1263,10 @@ protected Object visitNode(Node node, Object context) private List types(Expression... types) { - return ImmutableList.copyOf(Iterables.transform(ImmutableList.copyOf(types), Functions.forMap(expressionTypes))); + return Stream.of(types) + .map(NodeRef::of) + .map(expressionTypes::get) + .collect(toImmutableList()); } private boolean hasUnresolvedValue(Object... values) @@ -1367,6 +1379,9 @@ public static Object invoke(ConnectorSession session, ScalarFunctionImplementati Class[] parameterArray = handle.type().parameterArray(); for (int i = 0; i < argumentValues.size(); i++) { Object argument = argumentValues.get(i); + if (function.getLambdaInterface().get(i).isPresent() && !MethodHandle.class.equals(function.getLambdaInterface().get(i).get())) { + argument = asInterfaceInstance(function.getLambdaInterface().get(i).get(), (MethodHandle) argument); + } if (function.getNullFlags().get(i)) { boolean isNull = argument == null; if (isNull) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/FragmentTableScanCounter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/FragmentTableScanCounter.java index 9bbaf6955c3ba..f6bf9949a66d9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/FragmentTableScanCounter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/FragmentTableScanCounter.java @@ -49,7 +49,7 @@ public static boolean hasMultipleSources(PlanNode... nodes) } private static class Visitor - extends PlanVisitor + extends PlanVisitor { @Override public Integer visitTableScan(TableScanNode node, Void context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java new file mode 100644 index 0000000000000..e3d613c0e8fd9 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/GroupingOperationRewriter.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.sql.analyzer.Analysis; +import com.facebook.presto.sql.analyzer.FieldId; +import com.facebook.presto.sql.analyzer.RelationId; +import com.facebook.presto.sql.analyzer.TypeSignatureProvider; +import com.facebook.presto.sql.tree.ArrayConstructor; +import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GenericLiteral; +import com.facebook.presto.sql.tree.GroupingOperation; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.QuerySpecification; +import com.facebook.presto.type.ListLiteralType; +import com.google.common.collect.ImmutableList; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.BIGINT_GROUPING; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.INTEGER_GROUPING; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.MAX_NUMBER_GROUPING_ARGUMENTS_INTEGER; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.resolveFunction; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public final class GroupingOperationRewriter +{ + private GroupingOperationRewriter() {} + + public static Expression rewriteGroupingOperation(GroupingOperation expression, QuerySpecification queryNode, Analysis analysis, Metadata metadata, Optional groupIdSymbol) + { + requireNonNull(queryNode, "node is null"); + requireNonNull(analysis, "analysis is null"); + requireNonNull(metadata, "metadata is null"); + requireNonNull(groupIdSymbol, "groupIdSymbol is null"); + + checkState(queryNode.getGroupBy().isPresent(), "GroupBy node must be present"); + + // No GroupIdNode and a GROUPING() operation imply a single grouping, which + // means that any columns specified as arguments to GROUPING() will be included + // in the group and none of them will be aggregated over. Hence, re-write the + // GroupingOperation to a constant literal of 0. + // See SQL:2011:4.16.2 and SQL:2011:6.9.10. + if (analysis.getGroupingSets(queryNode).size() == 1) { + if (shouldUseIntegerReturnType(expression)) { + return new LongLiteral("0"); + } + else { + return new GenericLiteral(StandardTypes.BIGINT, "0"); + } + } + else { + checkState(groupIdSymbol.isPresent(), "groupId symbol is missing"); + + Map, FieldId> columnReferenceFields = analysis.getColumnReferenceFields(); + RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId(); + List groupingOrdinals = expression.getGroupingColumns().stream() + .map(NodeRef::of) + .peek(groupingColumn -> checkState(columnReferenceFields.containsKey(groupingColumn), "the grouping column is not in the columnReferencesField map")) + .map(columnReferenceFields::get) + .map(fieldId -> translateFieldToLongLiteral(fieldId, relationId)) + .collect(toImmutableList()); + + List> groupingSetOrdinals = analysis.getGroupingSets(queryNode).stream() + .map(groupingSet -> groupingSet.stream() + .map(NodeRef::of) + .filter(columnReferenceFields::containsKey) + .map(columnReferenceFields::get) + .map(fieldId -> translateFieldToLongLiteral(fieldId, relationId)) + .collect(toImmutableList())) + .collect(toImmutableList()); + + List newGroupingArguments = ImmutableList.of( + groupIdSymbol.get().toSymbolReference(), + new Cast(new ArrayConstructor(groupingOrdinals), ListLiteralType.NAME), + new Cast(new ArrayConstructor(groupingSetOrdinals.stream().map(ArrayConstructor::new).collect(toImmutableList())), ListLiteralType.NAME) + ); + + FunctionCall rewritten = new FunctionCall( + expression.getLocation().get(), + shouldUseIntegerReturnType(expression) ? QualifiedName.of(INTEGER_GROUPING) : QualifiedName.of(BIGINT_GROUPING), + newGroupingArguments); + List functionArgumentTypes = Arrays.asList( + new TypeSignatureProvider(BIGINT.getTypeSignature()), + new TypeSignatureProvider(ListLiteralType.LIST_LITERAL.getTypeSignature()), + new TypeSignatureProvider(ListLiteralType.LIST_LITERAL.getTypeSignature()) + ); + resolveFunction(rewritten, functionArgumentTypes, metadata.getFunctionRegistry()); + + return rewritten; + } + } + + private static Expression translateFieldToLongLiteral(FieldId fieldId, RelationId requiredOriginRelationId) + { + // TODO: this section should be rewritten when support is added for GROUP BY columns to reference an outer scope + checkState(fieldId.getRelationId().equals(requiredOriginRelationId), "grouping arguments must all come from the same relation"); + return new LongLiteral(Integer.toString(fieldId.getFieldIndex())); + } + + private static boolean shouldUseIntegerReturnType(GroupingOperation groupingOperation) + { + return groupingOperation.getGroupingColumns().size() <= MAX_NUMBER_GROUPING_ARGUMENTS_INTEGER; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LambdaCaptureDesugaringRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LambdaCaptureDesugaringRewriter.java index 597293462e0c1..830e46df58c69 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LambdaCaptureDesugaringRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LambdaCaptureDesugaringRewriter.java @@ -98,8 +98,12 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Context context } newLambdaArguments.addAll(node.getArguments()); Expression rewrittenExpression = new LambdaExpression(newLambdaArguments.build(), replaceSymbols(rewrittenBody, captureSymbolToExtraSymbol.build())); - for (Symbol captureSymbol : captureSymbols) { - rewrittenExpression = new BindExpression(new SymbolReference(captureSymbol.getName()), rewrittenExpression); + + if (captureSymbols.size() != 0) { + List capturedValues = captureSymbols.stream() + .map(symbol -> new SymbolReference(symbol.getName())) + .collect(toImmutableList()); + rewrittenExpression = new BindExpression(capturedValues, rewrittenExpression); } context.getReferencedSymbols().addAll(captureSymbols); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java index 8f8b61eb2c2fd..a17813f551b49 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java @@ -75,6 +75,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public final class LiteralInterpreter @@ -193,7 +194,7 @@ else if (value.equals(Float.POSITIVE_INFINITY)) { } if (object instanceof Block) { - SliceOutput output = new DynamicSliceOutput(((Block) object).getSizeInBytes()); + SliceOutput output = new DynamicSliceOutput(toIntExact(((Block) object).getSizeInBytes())); BlockSerdeUtil.writeBlock(output, (Block) object); object = output.slice(); // This if condition will evaluate to true: object instanceof Slice && !type.equals(VARCHAR) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 729cda50e9d9d..5aaebc24afa1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.QueryPerformanceFetcher; import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.buffer.OutputBuffer; @@ -104,6 +105,7 @@ import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; @@ -144,7 +146,7 @@ import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -159,7 +161,6 @@ import io.airlift.log.Logger; import io.airlift.units.DataSize; -import javax.annotation.Nullable; import javax.inject.Inject; import java.util.ArrayList; @@ -226,6 +227,7 @@ public class LocalExecutionPlanner private final Metadata metadata; private final SqlParser sqlParser; + private final CostCalculator costCalculator; private final Optional queryPerformanceFetcher; private final PageSourceProvider pageSourceProvider; @@ -250,6 +252,7 @@ public class LocalExecutionPlanner public LocalExecutionPlanner( Metadata metadata, SqlParser sqlParser, + CostCalculator costCalculator, Optional queryPerformanceFetcher, PageSourceProvider pageSourceProvider, IndexManager indexManager, @@ -275,6 +278,7 @@ public LocalExecutionPlanner( this.exchangeClientSupplier = exchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.expressionCompiler = requireNonNull(expressionCompiler, "compiler is null"); this.joinFilterFunctionCompiler = requireNonNull(joinFilterFunctionCompiler, "compiler is null"); @@ -343,8 +347,8 @@ public LocalExecutionPlan plan( Set partitioningColumns = partitioningScheme.getPartitioning().getColumns(); // partitioningColumns expected to have one column in the normal case, and zero columns when partitioning on a constant - checkArgument(!partitioningScheme.isReplicateNulls() || partitioningColumns.size() <= 1); - if (partitioningScheme.isReplicateNulls() && partitioningColumns.size() == 1) { + checkArgument(!partitioningScheme.isReplicateNullsAndAny() || partitioningColumns.size() <= 1); + if (partitioningScheme.isReplicateNullsAndAny() && partitioningColumns.size() == 1) { nullChannel = OptionalInt.of(outputLayout.indexOf(getOnlyElement(partitioningColumns))); } @@ -353,7 +357,14 @@ public LocalExecutionPlan plan( plan, outputLayout, types, - new PartitionedOutputFactory(partitionFunction, partitionChannels, partitionConstants, nullChannel, outputBuffer, maxPagePartitioningBufferSize)); + new PartitionedOutputFactory( + partitionFunction, + partitionChannels, + partitionConstants, + partitioningScheme.isReplicateNullsAndAny(), + nullChannel, + outputBuffer, + maxPagePartitioningBufferSize)); } public LocalExecutionPlan plan(Session session, @@ -563,7 +574,7 @@ public List getDriverFactories() } private class Visitor - extends PlanVisitor + extends PlanVisitor { private final Session session; @@ -597,7 +608,7 @@ public PhysicalOperation visitExplainAnalyze(ExplainAnalyzeNode node, LocalExecu checkState(queryPerformanceFetcher.isPresent(), "ExplainAnalyze can only run on coordinator"); PhysicalOperation source = node.getSource().accept(this, context); - OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(context.getNextOperatorId(), node.getId(), queryPerformanceFetcher.get(), metadata); + OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(context.getNextOperatorId(), node.getId(), queryPerformanceFetcher.get(), metadata, costCalculator); return new PhysicalOperation(operatorFactory, makeLayout(node), source); } @@ -801,7 +812,7 @@ public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext cont (int) node.getCount(), sortChannels, sortOrders, - node.isPartial(), + node.getStep().equals(TopNNode.Step.PARTIAL), maxPartialAggregationMemorySize); return new PhysicalOperation(operator, source.getLayout(), source); @@ -1043,7 +1054,7 @@ private PhysicalOperation visitScanFilterAndProject( rewrittenProjections.add(symbolToInputRewriter.rewrite(assignments.get(symbol))); } - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypesFromInput( context.getSession(), metadata, sqlParser, @@ -1069,7 +1080,7 @@ private PhysicalOperation visitScanFilterAndProject( cursorProcessor, pageProcessor, columns, - Lists.transform(rewrittenProjections, forMap(expressionTypes))); + getTypes(rewrittenProjections, expressionTypes)); return new PhysicalOperation(operatorFactory, outputMappings); } @@ -1080,7 +1091,7 @@ private PhysicalOperation visitScanFilterAndProject( context.getNextOperatorId(), planNodeId, pageProcessor, - Lists.transform(rewrittenProjections, forMap(expressionTypes))); + getTypes(rewrittenProjections, expressionTypes)); return new PhysicalOperation(operatorFactory, outputMappings, source); } @@ -1121,7 +1132,7 @@ private PhysicalOperation visitScanFilterAndProject( () -> cursorProcessor, () -> pageProcessor, columns, - Lists.transform(rewrittenProjections, forMap(expressionTypes))); + getTypes(rewrittenProjections, expressionTypes)); return new PhysicalOperation(operatorFactory, outputMappings); } @@ -1130,7 +1141,7 @@ private PhysicalOperation visitScanFilterAndProject( context.getNextOperatorId(), planNodeId, () -> pageProcessor, - Lists.transform(rewrittenProjections, forMap(expressionTypes))); + getTypes(rewrittenProjections, expressionTypes)); return new PhysicalOperation(operatorFactory, outputMappings, source); } } @@ -1150,7 +1161,7 @@ private PageProcessor createInterpretedColumnarPageProcessor( return new PageProcessor(pageFilter, pageProjections); } - private RowExpression toRowExpression(Expression expression, IdentityLinkedHashMap types) + private RowExpression toRowExpression(Expression expression, Map, Type> types) { return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); } @@ -1199,7 +1210,7 @@ public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext PageBuilder pageBuilder = new PageBuilder(outputTypes); for (List row : node.getRows()) { pageBuilder.declarePosition(); - IdentityLinkedHashMap expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = getExpressionTypes( context.getSession(), metadata, sqlParser, @@ -1561,7 +1572,13 @@ private LookupSourceFactory createLookupSourceFactory( Optional buildHashChannel = buildHashSymbol.map(channelGetter(buildSource)); Optional filterFunctionFactory = node.getFilter() - .map(filterExpression -> compileJoinFilterFunction(filterExpression, probeLayout, buildSource.getLayout(), context.getTypes(), context.getSession())); + .map(filterExpression -> compileJoinFilterFunction( + filterExpression, + node.getSortExpression(), + probeLayout, + buildSource.getLayout(), + context.getTypes(), + context.getSession())); HashBuilderOperatorFactory hashBuilderOperatorFactory = new HashBuilderOperatorFactory( buildContext.getNextOperatorId(), @@ -1591,6 +1608,7 @@ private LookupSourceFactory createLookupSourceFactory( private JoinFilterFunctionFactory compileJoinFilterFunction( Expression filterExpression, + Optional sortExpression, Map probeLayout, Map buildLayout, Map types, @@ -1602,10 +1620,12 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( .collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey()))); Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression); + Optional rewrittenSortExpression = sortExpression.map( + expression -> new SymbolToInputRewriter(buildLayout).rewrite(expression)); - Optional sortChannel = SortExpressionExtractor.extractSortExpression(buildLayout, rewrittenFilter); + Optional sortChannel = rewrittenSortExpression.map(SortExpression::fromExpression); - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypesFromInput( session, metadata, sqlParser, @@ -1871,7 +1891,7 @@ private AccumulatorFactory buildAccumulatorFactory( PhysicalOperation source, Signature function, FunctionCall call, - @Nullable Symbol mask) + Optional mask) { List arguments = new ArrayList<>(); for (Expression argument : call.getArguments()) { @@ -1881,7 +1901,7 @@ private AccumulatorFactory buildAccumulatorFactory( Optional maskChannel = Optional.empty(); if (mask != null) { - maskChannel = Optional.of(source.getLayout().get(mask)); + maskChannel = mask.map(value -> source.getLayout().get(value)); } return metadata.getFunctionRegistry().getAggregateFunctionImplementation(function).bind(arguments, maskChannel); @@ -1892,13 +1912,13 @@ private PhysicalOperation planGlobalAggregation(int operatorId, AggregationNode int outputChannel = 0; ImmutableMap.Builder outputMappings = ImmutableMap.builder(); List accumulatorFactories = new ArrayList<>(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + for (Map.Entry entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); - + Aggregation aggregation = entry.getValue(); accumulatorFactories.add(buildAccumulatorFactory(source, - node.getFunctions().get(symbol), - entry.getValue(), - node.getMasks().get(entry.getKey()))); + aggregation.getSignature(), + aggregation.getCall(), + aggregation.getMask())); outputMappings.put(symbol, outputChannel); // one aggregation per channel outputChannel++; } @@ -1918,14 +1938,15 @@ private PhysicalOperation planGroupByAggregation( List aggregationOutputSymbols = new ArrayList<>(); List accumulatorFactories = new ArrayList<>(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + for (Map.Entry entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); + Aggregation aggregation = entry.getValue(); accumulatorFactories.add(buildAccumulatorFactory( source, - node.getFunctions().get(symbol), - entry.getValue(), - node.getMasks().get(entry.getKey()))); + aggregation.getSignature(), + aggregation.getCall(), + aggregation.getMask())); aggregationOutputSymbols.add(symbol); } @@ -1985,6 +2006,14 @@ private PhysicalOperation planGroupByAggregation( } } + private static List getTypes(List expressions, Map, Type> expressionTypes) + { + return expressions.stream() + .map(NodeRef::of) + .map(expressionTypes::get) + .collect(toImmutableList()); + } + private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata) { WriterTarget target = node.getTarget(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index a342c8f08dff4..8634cfc6ce685 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -15,6 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.NewTableLayout; import com.facebook.presto.metadata.QualifiedObjectName; @@ -37,6 +39,7 @@ import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; @@ -50,14 +53,15 @@ import com.facebook.presto.sql.tree.Insert; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.Statement; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -88,24 +92,28 @@ public enum Stage private final SymbolAllocator symbolAllocator = new SymbolAllocator(); private final Metadata metadata; private final SqlParser sqlParser; + private final CostCalculator costCalculator; public LogicalPlanner(Session session, List planOptimizers, PlanNodeIdAllocator idAllocator, Metadata metadata, - SqlParser sqlParser) + SqlParser sqlParser, + CostCalculator costCalculator) { requireNonNull(session, "session is null"); requireNonNull(planOptimizers, "planOptimizers is null"); requireNonNull(idAllocator, "idAllocator is null"); requireNonNull(metadata, "metadata is null"); requireNonNull(sqlParser, "sqlParser is null"); + requireNonNull(costCalculator, "costCalculator is null"); this.session = session; this.planOptimizers = planOptimizers; this.idAllocator = idAllocator; this.metadata = metadata; this.sqlParser = sqlParser; + this.costCalculator = costCalculator; } public Plan plan(Analysis analysis) @@ -117,6 +125,8 @@ public Plan plan(Analysis analysis, Stage stage) { PlanNode root = planStatement(analysis, analysis.getStatement()); + PlanSanityChecker.validateIntermediatePlan(root, session, metadata, sqlParser, symbolAllocator.getTypes()); + if (stage.ordinal() >= Stage.OPTIMIZED.ordinal()) { for (PlanOptimizer optimizer : planOptimizers) { root = optimizer.optimize(root, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator); @@ -126,10 +136,12 @@ public Plan plan(Analysis analysis, Stage stage) if (stage.ordinal() >= Stage.OPTIMIZED_AND_VALIDATED.ordinal()) { // make sure we produce a valid plan after optimizations run. This is mainly to catch programming errors - PlanSanityChecker.validate(root, session, metadata, sqlParser, symbolAllocator.getTypes()); + PlanSanityChecker.validateFinalPlan(root, session, metadata, sqlParser, symbolAllocator.getTypes()); } - return new Plan(root, symbolAllocator.getTypes()); + Map planNodeCosts = costCalculator.calculateCostForPlan(session, symbolAllocator.getTypes(), root); + + return new Plan(root, symbolAllocator.getTypes(), planNodeCosts); } public PlanNode planStatement(Analysis analysis, Statement statement) @@ -383,18 +395,18 @@ private static List getOutputTableColumns(RelationPlan plan) return columns.build(); } - private static IdentityLinkedHashMap buildLambdaDeclarationToSymbolMap(Analysis analysis, SymbolAllocator symbolAllocator) + private static Map, Symbol> buildLambdaDeclarationToSymbolMap(Analysis analysis, SymbolAllocator symbolAllocator) { - IdentityLinkedHashMap resultMap = new IdentityLinkedHashMap<>(); - for (Map.Entry entry : analysis.getTypes().entrySet()) { - if (!(entry.getKey() instanceof LambdaArgumentDeclaration)) { + Map, Symbol> resultMap = new LinkedHashMap<>(); + for (Map.Entry, Type> entry : analysis.getTypes().entrySet()) { + if (!(entry.getKey().getNode() instanceof LambdaArgumentDeclaration)) { continue; } - LambdaArgumentDeclaration lambdaArgumentDeclaration = (LambdaArgumentDeclaration) entry.getKey(); + NodeRef lambdaArgumentDeclaration = NodeRef.of((LambdaArgumentDeclaration) entry.getKey().getNode()); if (resultMap.containsKey(lambdaArgumentDeclaration)) { continue; } - resultMap.put(lambdaArgumentDeclaration, symbolAllocator.newSymbol(lambdaArgumentDeclaration, entry.getValue())); + resultMap.put(lambdaArgumentDeclaration, symbolAllocator.newSymbol(lambdaArgumentDeclaration.getNode(), entry.getValue())); } return resultMap; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java index e331081b0a92b..570ce889baceb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PartitioningScheme.java @@ -32,7 +32,7 @@ public class PartitioningScheme private final Partitioning partitioning; private final List outputLayout; private final Optional hashColumn; - private final boolean replicateNulls; + private final boolean replicateNullsAndAny; private final Optional bucketToPartition; public PartitioningScheme(Partitioning partitioning, List outputLayout) @@ -60,7 +60,7 @@ public PartitioningScheme( @JsonProperty("partitioning") Partitioning partitioning, @JsonProperty("outputLayout") List outputLayout, @JsonProperty("hashColumn") Optional hashColumn, - @JsonProperty("replicateNulls") boolean replicateNulls, + @JsonProperty("replicateNullsAndAny") boolean replicateNullsAndAny, @JsonProperty("bucketToPartition") Optional bucketToPartition) { this.partitioning = requireNonNull(partitioning, "partitioning is null"); @@ -74,8 +74,8 @@ public PartitioningScheme( hashColumn.ifPresent(column -> checkArgument(outputLayout.contains(column), "Output layout (%s) don't include hash column (%s)", outputLayout, column)); - checkArgument(!replicateNulls || columns.size() <= 1, "Must have at most one partitioning column when nullPartition is REPLICATE."); - this.replicateNulls = replicateNulls; + checkArgument(!replicateNullsAndAny || columns.size() <= 1, "Must have at most one partitioning column when nullPartition is REPLICATE."); + this.replicateNullsAndAny = replicateNullsAndAny; this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null"); } @@ -98,9 +98,9 @@ public Optional getHashColumn() } @JsonProperty - public boolean isReplicateNulls() + public boolean isReplicateNullsAndAny() { - return replicateNulls; + return replicateNullsAndAny; } @JsonProperty @@ -111,7 +111,7 @@ public Optional getBucketToPartition() public PartitioningScheme withBucketToPartition(Optional bucketToPartition) { - return new PartitioningScheme(partitioning, outputLayout, hashColumn, replicateNulls, bucketToPartition); + return new PartitioningScheme(partitioning, outputLayout, hashColumn, replicateNullsAndAny, bucketToPartition); } public PartitioningScheme translateOutputLayout(List newOutputLayout) @@ -126,7 +126,7 @@ public PartitioningScheme translateOutputLayout(List newOutputLayout) .map(outputLayout::indexOf) .map(newOutputLayout::get); - return new PartitioningScheme(newPartitioning, newOutputLayout, newHashSymbol, replicateNulls, bucketToPartition); + return new PartitioningScheme(newPartitioning, newOutputLayout, newHashSymbol, replicateNullsAndAny, bucketToPartition); } @Override @@ -141,14 +141,14 @@ public boolean equals(Object o) PartitioningScheme that = (PartitioningScheme) o; return Objects.equals(partitioning, that.partitioning) && Objects.equals(outputLayout, that.outputLayout) && - replicateNulls == that.replicateNulls && + replicateNullsAndAny == that.replicateNullsAndAny && Objects.equals(bucketToPartition, that.bucketToPartition); } @Override public int hashCode() { - return Objects.hash(partitioning, outputLayout, replicateNulls, bucketToPartition); + return Objects.hash(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition); } @Override @@ -158,7 +158,7 @@ public String toString() .add("partitioning", partitioning) .add("outputLayout", outputLayout) .add("hashChannel", hashColumn) - .add("replicateNulls", replicateNulls) + .add("replicateNullsAndAny", replicateNullsAndAny) .add("bucketToPartition", bucketToPartition) .toString(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java index 19f0b0f87a344..8a92461e3289a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/Plan.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableMap; import java.util.Map; @@ -25,14 +27,17 @@ public class Plan { private final PlanNode root; private final Map types; + private final Map planNodeCosts; - public Plan(PlanNode root, Map types) + public Plan(PlanNode root, Map types, Map planNodeCosts) { requireNonNull(root, "root is null"); requireNonNull(types, "types is null"); + requireNonNull(planNodeCosts, "planNodeCosts is null"); this.root = root; this.types = ImmutableMap.copyOf(types); + this.planNodeCosts = planNodeCosts; } public PlanNode getRoot() @@ -44,4 +49,9 @@ public Map getTypes() { return types; } + + public Map getPlanNodeCosts() + { + return planNodeCosts; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index da50ce11b6218..62afb0cc324aa 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -17,6 +17,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.metadata.TableLayout.NodePartitioning; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; @@ -55,6 +56,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; /** @@ -255,18 +257,46 @@ public FragmentProperties setSingleNodeDistribution() public FragmentProperties setDistribution(PartitioningHandle distribution) { - if (partitioningHandle.isPresent() && !partitioningHandle.get().equals(distribution) && !partitioningHandle.get().equals(SOURCE_DISTRIBUTION)) { - checkState(partitioningHandle.get().isSingleNode(), - "Cannot set distribution to %s. Already set to %s", - distribution, - partitioningHandle); + if (partitioningHandle.isPresent()) { + chooseDistribution(distribution); return this; } - partitioningHandle = Optional.of(distribution); + partitioningHandle = Optional.of(distribution); return this; } + private void chooseDistribution(PartitioningHandle distribution) + { + checkState(partitioningHandle.isPresent(), "No partitioning to choose from"); + + if (partitioningHandle.get().equals(distribution) || + partitioningHandle.get().isSingleNode() || + isCompatibleSystemPartitioning(distribution)) { + return; + } + if (partitioningHandle.get().equals(SOURCE_DISTRIBUTION)) { + partitioningHandle = Optional.of(distribution); + return; + } + throw new IllegalStateException(format( + "Cannot set distribution to %s. Already set to %s", + distribution, + partitioningHandle)); + } + + private boolean isCompatibleSystemPartitioning(PartitioningHandle distribution) + { + ConnectorPartitioningHandle currentHandle = partitioningHandle.get().getConnectorHandle(); + ConnectorPartitioningHandle distributionHandle = distribution.getConnectorHandle(); + if ((currentHandle instanceof SystemPartitioningHandle) && + (distributionHandle instanceof SystemPartitioningHandle)) { + return ((SystemPartitioningHandle) currentHandle).getPartitioning() == + ((SystemPartitioningHandle) distributionHandle).getPartitioning(); + } + return false; + } + public FragmentProperties setCoordinatorOnlyDistribution() { if (partitioningHandle.isPresent() && partitioningHandle.get().isCoordinatorOnly()) { @@ -333,7 +363,7 @@ public Set getPartitionedSources() } private static class SchedulingOrderVisitor - extends PlanVisitor, Void> + extends PlanVisitor> { public List getSchedulingOrder(PlanNode node) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index dca4f4fc562d5..7617a2a419f74 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -19,41 +19,52 @@ import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations; +import com.facebook.presto.sql.planner.iterative.rule.CreatePartialTopN; +import com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins; import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroLimit; import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroSample; import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations; import com.facebook.presto.sql.planner.iterative.rule.InlineProjections; +import com.facebook.presto.sql.planner.iterative.rule.MergeAdjacentWindows; import com.facebook.presto.sql.planner.iterative.rule.MergeFilters; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithTopN; import com.facebook.presto.sql.planner.iterative.rule.MergeLimits; +import com.facebook.presto.sql.planner.iterative.rule.PruneCrossJoinColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneJoinChildrenColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneJoinColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneMarkDistinctColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinFilteringSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns; +import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughProject; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughSemiJoin; +import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange; +import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion; +import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant; import com.facebook.presto.sql.planner.iterative.rule.SingleMarkDistinctToGroupBy; -import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsByPartitionsOrder; -import com.facebook.presto.sql.planner.iterative.rule.TransformExistsApplyToScalarApply; +import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications; +import com.facebook.presto.sql.planner.iterative.rule.TransformCorrelatedInPredicateToJoin; +import com.facebook.presto.sql.planner.iterative.rule.TransformExistsApplyToLateralNode; import com.facebook.presto.sql.planner.optimizations.AddExchanges; import com.facebook.presto.sql.planner.optimizations.AddLocalExchanges; import com.facebook.presto.sql.planner.optimizations.BeginTableWrite; import com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions; import com.facebook.presto.sql.planner.optimizations.DesugaringOptimizer; import com.facebook.presto.sql.planner.optimizations.DetermineJoinDistributionType; -import com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins; import com.facebook.presto.sql.planner.optimizations.HashGenerationOptimizer; import com.facebook.presto.sql.planner.optimizations.ImplementIntersectAndExceptAsUnion; import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.optimizations.LimitPushDown; -import com.facebook.presto.sql.planner.optimizations.MergeProjections; -import com.facebook.presto.sql.planner.optimizations.MergeWindows; import com.facebook.presto.sql.planner.optimizations.MetadataDeleteOptimizer; import com.facebook.presto.sql.planner.optimizations.MetadataQueryOptimizer; import com.facebook.presto.sql.planner.optimizations.OptimizeMixedDistinctAggregations; @@ -64,13 +75,15 @@ import com.facebook.presto.sql.planner.optimizations.ProjectionPushDown; import com.facebook.presto.sql.planner.optimizations.PruneUnreferencedOutputs; import com.facebook.presto.sql.planner.optimizations.PushTableWriteThroughUnion; -import com.facebook.presto.sql.planner.optimizations.RemoveUnreferencedScalarInputApplyNodes; +import com.facebook.presto.sql.planner.optimizations.RemoveUnreferencedScalarLateralNodes; import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer; import com.facebook.presto.sql.planner.optimizations.SimplifyExpressions; +import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedNoAggregationSubqueryToJoin; import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin; -import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToScalarApply; +import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedSingleRowSubqueryToProject; +import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedInPredicateSubqueryToSemiJoin; -import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedScalarToJoin; +import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedLateralToJoin; import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences; import com.facebook.presto.sql.planner.optimizations.WindowFilterPushDown; import com.google.common.collect.ImmutableList; @@ -116,13 +129,30 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea Set predicatePushDownRules = ImmutableSet.of( new MergeFilters()); + // TODO: Once we've migrated handling all the plan node types, replace uses of PruneUnreferencedOutputs with an IterativeOptimizer containing these rules. + Set columnPruningRules = ImmutableSet.of( + new PruneCrossJoinColumns(), + new PruneJoinChildrenColumns(), + new PruneJoinColumns(), + new PruneMarkDistinctColumns(), + new PruneSemiJoinColumns(), + new PruneSemiJoinFilteringSourceColumns(), + new PruneValuesColumns(), + new PruneTableScanColumns()); + IterativeOptimizer inlineProjections = new IterativeOptimizer( stats, - ImmutableList.of(new MergeProjections()), ImmutableSet.of( new InlineProjections(), new RemoveRedundantIdentityProjections())); + IterativeOptimizer projectionPushDown = new IterativeOptimizer( + stats, + ImmutableList.of(new ProjectionPushDown()), + ImmutableSet.of( + new PushProjectionThroughUnion(), + new PushProjectionThroughExchange())); + builder.add( new DesugaringOptimizer(metadata, sqlParser), // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers new CanonicalizeExpressions(), @@ -130,6 +160,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea stats, ImmutableSet.builder() .addAll(predicatePushDownRules) + .addAll(columnPruningRules) .addAll(ImmutableSet.of( new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -141,12 +172,9 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea new MergeLimitWithTopN(), new PushLimitThroughMarkDistinct(), new PushLimitThroughSemiJoin(), - new MergeLimitWithDistinct(), - - new PruneValuesColumns(), - new PruneTableScanColumns())) + new MergeLimitWithDistinct())) .build() - ), + ), new IterativeOptimizer( stats, ImmutableSet.of( @@ -165,16 +193,42 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea inlineProjections, new IterativeOptimizer( stats, - ImmutableSet.of(new TransformExistsApplyToScalarApply(metadata.getFunctionRegistry()))), - new TransformQuantifiedComparisonApplyToScalarApply(metadata), - new RemoveUnreferencedScalarInputApplyNodes(), - new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), - new TransformUncorrelatedScalarToJoin(), - new TransformCorrelatedScalarAggregationToJoin(metadata), + ImmutableSet.of(new TransformExistsApplyToLateralNode(metadata.getFunctionRegistry()))), + new TransformQuantifiedComparisonApplyToLateralJoin(metadata), + new IterativeOptimizer(stats, + ImmutableList.of( + new RemoveUnreferencedScalarLateralNodes(), + new TransformUncorrelatedLateralToJoin(), + new TransformUncorrelatedInPredicateSubqueryToSemiJoin()), + ImmutableSet.of( + new com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes(), + new com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedLateralToJoin(), + new com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin() + ) + ), + new IterativeOptimizer( + stats, + ImmutableList.of(new TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry())), + ImmutableSet.of(new com.facebook.presto.sql.planner.iterative.rule.TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionRegistry()))), + new IterativeOptimizer( + stats, + ImmutableSet.of( + new TransformCorrelatedInPredicateToJoin(), // must be run after PruneUnreferencedOutputs + new ImplementFilteredAggregations()) + ), + new TransformCorrelatedNoAggregationSubqueryToJoin(), + new TransformCorrelatedSingleRowSubqueryToProject(), new PredicatePushDown(metadata, sqlParser), + new PruneUnreferencedOutputs(), + new IterativeOptimizer( + stats, + ImmutableSet.of( + new RemoveRedundantIdentityProjections(), + new PushAggregationThroughOuterJoin()) + ), inlineProjections, new SimplifyExpressions(metadata, sqlParser), // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations - new ProjectionPushDown(), + projectionPushDown, new UnaliasSymbolReferences(), // Run again because predicate pushdown and projection pushdown might add more projections new PruneUnreferencedOutputs(), // Make sure to run this before index join. Filtered projections may not have all the columns. new IndexJoinOptimizer(metadata), // Run this after projections and filters have been fully simplified and pushed down @@ -182,13 +236,13 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea stats, ImmutableSet.of(new SimplifyCountOverConstant())), new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits - new MergeWindows(), new IterativeOptimizer( stats, ImmutableSet.of( // add UnaliasSymbolReferences when it's ported new RemoveRedundantIdentityProjections(), - new SwapAdjacentWindowsByPartitionsOrder())), + new SwapAdjacentWindowsBySpecifications(), + new MergeAdjacentWindows())), inlineProjections, new PruneUnreferencedOutputs(), // Make sure to run this at the end to help clean the plan for logging/execution and not remove info that other optimizers might need at an earlier point new IterativeOptimizer( @@ -196,9 +250,13 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea ImmutableSet.of(new RemoveRedundantIdentityProjections()) ), new MetadataQueryOptimizer(metadata), - new EliminateCrossJoins(), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again + new IterativeOptimizer( + stats, + ImmutableList.of(new com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins()), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again + ImmutableSet.of(new EliminateCrossJoins()) + ), new PredicatePushDown(metadata, sqlParser), - new ProjectionPushDown()); + projectionPushDown); if (featuresConfig.isOptimizeSingleDistinct()) { builder.add( @@ -209,6 +267,11 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea } builder.add(new OptimizeMixedDistinctAggregations(metadata)); + builder.add(new IterativeOptimizer( + stats, + ImmutableSet.of( + new CreatePartialTopN(), + new PushTopNThroughUnion()))); if (!forceSingleNode) { builder.add(new DetermineJoinDistributionType()); // Must run before AddExchanges @@ -225,7 +288,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea )); builder.add(new PredicatePushDown(metadata, sqlParser)); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate - builder.add(new ProjectionPushDown()); + builder.add(projectionPushDown); builder.add(inlineProjections); builder.add(new UnaliasSymbolReferences()); // Run unalias after merging projections to simplify projections more efficiently builder.add(new PruneUnreferencedOutputs()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index ec15c3a46b193..bfa55f8b40c9c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.block.SortOrder; @@ -28,6 +27,7 @@ import com.facebook.presto.sql.analyzer.RelationType; import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -49,8 +49,10 @@ import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; @@ -60,7 +62,6 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -69,7 +70,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -77,7 +77,6 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.StreamSupport; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -88,6 +87,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Streams.stream; import static java.util.Objects.requireNonNull; class QueryPlanner @@ -95,7 +95,7 @@ class QueryPlanner private final Analysis analysis; private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; - private final IdentityLinkedHashMap lambdaDeclarationToSymbolMap; + private final Map, Symbol> lambdaDeclarationToSymbolMap; private final Metadata metadata; private final Session session; private final SubqueryPlanner subqueryPlanner; @@ -104,7 +104,7 @@ class QueryPlanner Analysis analysis, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - IdentityLinkedHashMap lambdaDeclarationToSymbolMap, + Map, Symbol> lambdaDeclarationToSymbolMap, Metadata metadata, Session session) { @@ -519,8 +519,9 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) aggregationTranslations.copyMappingsFrom(groupingTranslations); // 2.d. Rewrite aggregates - ImmutableMap.Builder aggregationAssignments = ImmutableMap.builder(); - ImmutableMap.Builder functions = ImmutableMap.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + // Map from aggregate function arguments to marker symbols, so that we can reuse the markers, if two aggregates have the same argument + Map, Symbol> argumentMarkers = new HashMap<>(); boolean needPostProjectionCoercion = false; for (FunctionCall aggregate : analysis.getAggregates(node)) { Expression parametersReplaced = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), aggregate); @@ -534,34 +535,27 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) rewritten = ((Cast) rewritten).getExpression(); needPostProjectionCoercion = true; } - aggregationAssignments.put(newSymbol, (FunctionCall) rewritten); aggregationTranslations.put(parametersReplaced, newSymbol); - functions.put(newSymbol, analysis.getFunctionSignature(aggregate)); - } - - // 2.e. Mark distinct rows for each aggregate that has DISTINCT - // Map from aggregate function arguments to marker symbols, so that we can reuse the markers, if two aggregates have the same argument - Map, Symbol> argumentMarkers = new HashMap<>(); - // Map from aggregate functions to marker symbols - Map masks = new HashMap<>(); - for (FunctionCall aggregate : Iterables.filter(analysis.getAggregates(node), FunctionCall::isDistinct)) { - Set args = ImmutableSet.copyOf(aggregate.getArguments()); - Symbol marker = argumentMarkers.get(args); - Symbol aggregateSymbol = aggregationTranslations.get(aggregate); - if (marker == null) { - if (args.size() == 1) { - marker = symbolAllocator.newSymbol(getOnlyElement(args), BOOLEAN, "distinct"); - } - else { - marker = symbolAllocator.newSymbol(aggregateSymbol.getName(), BOOLEAN, "distinct"); + Optional marker = Optional.empty(); + if (aggregate.isDistinct()) { + Set args = ImmutableSet.copyOf(aggregate.getArguments()); + marker = Optional.ofNullable(argumentMarkers.get(args)); + Symbol aggregateSymbol = aggregationTranslations.get(aggregate); + if (!marker.isPresent()) { + if (args.size() == 1) { + marker = Optional.of(symbolAllocator.newSymbol(getOnlyElement(args), BOOLEAN, "distinct")); + } + else { + marker = Optional.of(symbolAllocator.newSymbol(aggregateSymbol.getName(), BOOLEAN, "distinct")); + } + argumentMarkers.put(args, marker.get()); } - argumentMarkers.put(args, marker); } - - masks.put(aggregateSymbol, marker); + aggregations.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionSignature(aggregate), marker)); } + // 2.e. Mark distinct rows for each aggregate that has DISTINCT for (Map.Entry, Symbol> entry : argumentMarkers.entrySet()) { ImmutableList.Builder builder = ImmutableList.builder(); builder.addAll(groupingSymbols.stream() @@ -585,9 +579,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) AggregationNode aggregationNode = new AggregationNode( idAllocator.getNextId(), subPlan.getRoot(), - aggregationAssignments.build(), - functions.build(), - masks, + aggregations.build(), groupingSymbols, AggregationNode.Step.SINGLE, Optional.empty(), @@ -601,7 +593,39 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) if (needPostProjectionCoercion) { return explicitCoercionFields(subPlan, distinctGroupingColumns, analysis.getAggregates(node)); } - return subPlan; + + // 4. Project and re-write all grouping functions + return handleGroupingOperations(subPlan, node, groupIdSymbol); + } + + private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecification node, Optional groupIdSymbol) + { + if (analysis.getGroupingOperations(node).isEmpty()) { + return subPlan; + } + + TranslationMap newTranslations = subPlan.copyTranslations(); + + Assignments.Builder projections = Assignments.builder(); + projections.putIdentities(subPlan.getRoot().getOutputSymbols()); + + for (GroupingOperation groupingOperation : analysis.getGroupingOperations(node)) { + Expression rewritten = GroupingOperationRewriter.rewriteGroupingOperation(groupingOperation, node, analysis, metadata, groupIdSymbol); + Type coercion = analysis.getCoercion(groupingOperation); + Symbol symbol = symbolAllocator.newSymbol(rewritten, analysis.getTypeWithCoercions(groupingOperation)); + if (coercion != null) { + rewritten = new Cast( + rewritten, + coercion.getTypeSignature().toString(), + false, + metadata.getTypeManager().isTypeOnlyCoercion(analysis.getType(groupingOperation), coercion)); + } + projections.put(symbol, rewritten); + newTranslations.addIntermediateMapping(groupingOperation, rewritten); + newTranslations.put(rewritten, symbol); + } + + return new PlanBuilder(newTranslations, new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), projections.build()), analysis.getParameters()); } private PlanBuilder window(PlanBuilder subPlan, OrderBy node) @@ -758,8 +782,6 @@ private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node) idAllocator.getNextId(), subPlan.getRoot(), ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), ImmutableList.of(subPlan.getRoot().getOutputSymbols()), AggregationNode.Step.SINGLE, Optional.empty(), @@ -801,7 +823,7 @@ private PlanBuilder sort(PlanBuilder subPlan, Optional orderBy, Optiona PlanNode planNode; if (limit.isPresent() && !limit.get().equalsIgnoreCase("all")) { - planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), Long.parseLong(limit.get()), orderBySymbols.build(), orderings, false); + planNode = new TopNNode(idAllocator.getNextId(), subPlan.getRoot(), Long.parseLong(limit.get()), orderBySymbols.build(), orderings, TopNNode.Step.SINGLE); } else { planNode = new SortNode(idAllocator.getNextId(), subPlan.getRoot(), orderBySymbols.build(), orderings); @@ -841,9 +863,8 @@ private static List toSymbolReferences(List symbols) private static Map symbolsForExpressions(PlanBuilder builder, Iterable expressions) { - Set added = new HashSet<>(); - return StreamSupport.stream(expressions.spliterator(), false) - .filter(added::add) + return stream(expressions) + .distinct() .collect(toImmutableMap(expression -> expression, builder::translate)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 57b3fa0172240..6ce3c4ab570f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -18,6 +18,8 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.analyzer.Analysis; @@ -40,7 +42,6 @@ import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; -import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; @@ -52,7 +53,8 @@ import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.JoinUsing; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; -import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.Lateral; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; @@ -66,9 +68,6 @@ import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; import com.facebook.presto.sql.tree.Values; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; @@ -79,6 +78,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -97,7 +97,7 @@ class RelationPlanner private final Analysis analysis; private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; - private final IdentityLinkedHashMap lambdaDeclarationToSymbolMap; + private final Map, Symbol> lambdaDeclarationToSymbolMap; private final Metadata metadata; private final Session session; private final SubqueryPlanner subqueryPlanner; @@ -106,7 +106,7 @@ class RelationPlanner Analysis analysis, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - IdentityLinkedHashMap lambdaDeclarationToSymbolMap, + Map, Symbol> lambdaDeclarationToSymbolMap, Metadata metadata, Session session) { @@ -185,19 +185,20 @@ protected RelationPlan visitJoin(Join node, Void context) // TODO: translate the RIGHT join into a mirrored LEFT join when we refactor (@martint) RelationPlan leftPlan = process(node.getLeft(), context); - // Convert CROSS JOIN UNNEST to an UnnestNode - if (node.getRight() instanceof Unnest || (node.getRight() instanceof AliasedRelation && ((AliasedRelation) node.getRight()).getRelation() instanceof Unnest)) { - Unnest unnest; - if (node.getRight() instanceof AliasedRelation) { - unnest = (Unnest) ((AliasedRelation) node.getRight()).getRelation(); - } - else { - unnest = (Unnest) node.getRight(); + Optional unnest = getUnnest(node.getRight()); + if (unnest.isPresent()) { + if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) { + throw notSupportedException(unnest.get(), "UNNEST on other than the right side of CROSS JOIN"); } + return planCrossJoinUnnest(leftPlan, node, unnest.get()); + } + + Optional lateral = getLateral(node.getRight()); + if (lateral.isPresent()) { if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) { - throw notSupportedException(unnest, "UNNEST on other than the right side of CROSS JOIN"); + throw notSupportedException(lateral.get(), "LATERAL on other than the right side of CROSS JOIN"); } - return planCrossJoinUnnest(leftPlan, node, unnest); + return planLateralJoin(node, leftPlan, lateral.get()); } RelationPlan rightPlan = process(node.getRight(), context); @@ -233,7 +234,7 @@ protected RelationPlan visitJoin(Join node, Void context) continue; } - Set dependencies = DependencyExtractor.extractNames(conjunct, analysis.getColumnReferences()); + Set dependencies = SymbolsExtractor.extractNames(conjunct, analysis.getColumnReferences()); boolean isJoinUsing = node.getCriteria().filter(JoinUsing.class::isInstance).isPresent(); if (!isJoinUsing && (dependencies.stream().allMatch(left::canResolve) || dependencies.stream().allMatch(right::canResolve))) { // If the conjunct can be evaluated entirely with the inputs on either side of the join, add @@ -247,8 +248,8 @@ else if (conjunct instanceof ComparisonExpression) { Expression firstExpression = ((ComparisonExpression) conjunct).getLeft(); Expression secondExpression = ((ComparisonExpression) conjunct).getRight(); ComparisonExpressionType comparisonType = ((ComparisonExpression) conjunct).getType(); - Set firstDependencies = DependencyExtractor.extractNames(firstExpression, analysis.getColumnReferences()); - Set secondDependencies = DependencyExtractor.extractNames(secondExpression, analysis.getColumnReferences()); + Set firstDependencies = SymbolsExtractor.extractNames(firstExpression, analysis.getColumnReferences()); + Set secondDependencies = SymbolsExtractor.extractNames(secondExpression, analysis.getColumnReferences()); if (firstDependencies.stream().allMatch(left::canResolve) && secondDependencies.stream().allMatch(right::canResolve)) { leftComparisonExpressions.add(firstExpression); @@ -365,6 +366,43 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende return new RelationPlan(root, analysis.getScope(node), outputSymbols); } + private Optional getUnnest(Relation relation) + { + if (relation instanceof AliasedRelation) { + return getUnnest(((AliasedRelation) relation).getRelation()); + } + if (relation instanceof Unnest) { + return Optional.of((Unnest) relation); + } + return Optional.empty(); + } + + private Optional getLateral(Relation relation) + { + if (relation instanceof AliasedRelation) { + return getLateral(((AliasedRelation) relation).getRelation()); + } + if (relation instanceof Lateral) { + return Optional.of((Lateral) relation); + } + return Optional.empty(); + } + + private RelationPlan planLateralJoin(Join join, RelationPlan leftPlan, Lateral lateral) + { + RelationPlan rightPlan = process(lateral.getQuery(), null); + PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan); + PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan); + + PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(leftPlanBuilder, rightPlanBuilder, lateral.getQuery(), true); + + List outputSymbols = ImmutableList.builder() + .addAll(leftPlan.getRoot().getOutputSymbols()) + .addAll(rightPlan.getRoot().getOutputSymbols()) + .build(); + return new RelationPlan(planBuilder.getRoot(), analysis.getScope(join), outputSymbols); + } + private static boolean isEqualComparisonExpression(Expression conjunct) { return conjunct instanceof ComparisonExpression && ((ComparisonExpression) conjunct).getType() == ComparisonExpressionType.EQUAL; @@ -409,16 +447,6 @@ else if (type instanceof MapType) { return new RelationPlan(unnestNode, analysis.getScope(joinNode), unnestNode.getOutputSymbols()); } - private static Expression oneIfNull(Optional symbol) - { - if (symbol.isPresent()) { - return new CoalesceExpression(symbol.get().toSymbolReference(), new LongLiteral("1")); - } - else { - return new LongLiteral("1"); - } - } - @Override protected RelationPlan visitTableSubquery(TableSubquery node, Void context) { @@ -660,8 +688,6 @@ private PlanNode distinct(PlanNode node) return new AggregationNode(idAllocator.getNextId(), node, ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), ImmutableList.of(node.getOutputSymbols()), AggregationNode.Step.SINGLE, Optional.empty(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SimplePlanVisitor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SimplePlanVisitor.java index 0ef325258042a..28383544d6e56 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SimplePlanVisitor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SimplePlanVisitor.java @@ -17,7 +17,7 @@ import com.facebook.presto.sql.planner.plan.PlanVisitor; public class SimplePlanVisitor - extends PlanVisitor + extends PlanVisitor { @Override protected Void visitPlan(PlanNode node, C context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java index 2812022bf1b6e..0be166375ec05 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java @@ -18,20 +18,21 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.Node; -import com.google.common.collect.ImmutableSet; +import com.facebook.presto.sql.tree.SymbolReference; -import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; /** * Currently this class handles only simple expressions like: * - * A.a < B.x. + * A.a < B.x * * It could be extended to handle any expressions like: * @@ -49,9 +50,12 @@ public final class SortExpressionExtractor { private SortExpressionExtractor() {} - public static Optional extractSortExpression(Map buildLayout, Expression filter) + public static Optional extractSortExpression(Set buildSymbols, Expression filter) { - Set buildFields = ImmutableSet.copyOf(buildLayout.values()); + if (!DeterminismEvaluator.isDeterministic(filter)) { + return Optional.empty(); + } + if (filter instanceof ComparisonExpression) { ComparisonExpression comparison = (ComparisonExpression) filter; switch (comparison.getType()) { @@ -59,14 +63,14 @@ public static Optional extractSortExpression(Map sortChannel = asBuildFieldReference(buildFields, comparison.getRight()); - boolean hasBuildReferencesOnOtherSide = hasBuildFieldReference(buildFields, comparison.getLeft()); + Optional sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight()); + boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getLeft()); if (!sortChannel.isPresent()) { - sortChannel = asBuildFieldReference(buildFields, comparison.getLeft()); - hasBuildReferencesOnOtherSide = hasBuildFieldReference(buildFields, comparison.getRight()); + sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft()); + hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight()); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { - return Optional.of(new SortExpression(sortChannel.get())); + return sortChannel.map(symbolReference -> (Expression) symbolReference); } return Optional.empty(); default: @@ -77,30 +81,32 @@ public static Optional extractSortExpression(Map asBuildFieldReference(Set buildLayout, Expression expression) + private static Optional asBuildSymbolReference(Set buildLayout, Expression expression) { - if (expression instanceof FieldReference) { - FieldReference field = (FieldReference) expression; - if (buildLayout.contains(field.getFieldIndex())) { - return Optional.of(field.getFieldIndex()); + if (expression instanceof SymbolReference) { + SymbolReference symbolReference = (SymbolReference) expression; + if (buildLayout.contains(new Symbol(symbolReference.getName()))) { + return Optional.of(symbolReference); } } return Optional.empty(); } - private static boolean hasBuildFieldReference(Set buildLayout, Expression expression) + private static boolean hasBuildSymbolReference(Set buildSymbols, Expression expression) { - return new BuildFieldReferenceFinder(buildLayout).process(expression); + return new BuildSymbolReferenceFinder(buildSymbols).process(expression); } - private static class BuildFieldReferenceFinder + private static class BuildSymbolReferenceFinder extends AstVisitor { - private final Set buildLayout; + private final Set buildSymbols; - public BuildFieldReferenceFinder(Set buildLayout) + public BuildSymbolReferenceFinder(Set buildSymbols) { - this.buildLayout = ImmutableSet.copyOf(requireNonNull(buildLayout, "buildLayout is null")); + this.buildSymbols = requireNonNull(buildSymbols, "buildSymbols is null").stream() + .map(Symbol::getName) + .collect(toImmutableSet()); } @Override @@ -115,9 +121,9 @@ protected Boolean visitNode(Node node, Void context) } @Override - protected Boolean visitFieldReference(FieldReference fieldReference, Void context) + protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context) { - return buildLayout.contains(fieldReference.getFieldIndex()); + return buildSymbols.contains(symbolReference.getName()); } } @@ -160,5 +166,11 @@ public String toString() .add("channel", channel) .toString(); } + + public static SortExpression fromExpression(Expression expression) + { + checkState(expression instanceof FieldReference, "Unsupported expression type [%s]", expression); + return new SortExpression(((FieldReference) expression).getFieldIndex()); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index c363222508514..49294c146959d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -16,12 +16,12 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.analyzer.Analysis; -import com.facebook.presto.sql.planner.optimizations.Predicates; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; @@ -36,12 +36,14 @@ import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression.Quantifier; +import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.SymbolReference; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.util.MorePredicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -70,7 +72,7 @@ class SubqueryPlanner private final Analysis analysis; private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; - private final IdentityLinkedHashMap lambdaDeclarationToSymbolMap; + private final Map, Symbol> lambdaDeclarationToSymbolMap; private final Metadata metadata; private final Session session; private final List parameters; @@ -79,7 +81,7 @@ class SubqueryPlanner Analysis analysis, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - IdentityLinkedHashMap lambdaDeclarationToSymbolMap, + Map, Symbol> lambdaDeclarationToSymbolMap, Metadata metadata, Session session, List parameters) @@ -221,23 +223,35 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE subqueryPlan = subqueryPlan.withNewRoot(new EnforceSingleRowNode(idAllocator.getNextId(), subqueryPlan.getRoot())); subqueryPlan = subqueryPlan.appendProjections(coercions, symbolAllocator, idAllocator); - Assignments.Builder subqueryAssignments = Assignments.builder(); Symbol uncoercedScalarSubquerySymbol = subqueryPlan.translate(uncoercedScalarSubquery); subPlan.getTranslations().put(uncoercedScalarSubquery, uncoercedScalarSubquerySymbol); - subqueryAssignments.put(uncoercedScalarSubquerySymbol, uncoercedScalarSubquerySymbol.toSymbolReference()); for (Expression coercion : coercions) { Symbol coercionSymbol = subqueryPlan.translate(coercion); subPlan.getTranslations().put(coercion, coercionSymbol); - subqueryAssignments.put(coercionSymbol, coercionSymbol.toSymbolReference()); } - return appendApplyNode( - subPlan, - scalarSubquery.getQuery(), - subqueryPlan, - subqueryAssignments.build(), - correlationAllowed); + return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed); + } + + public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed) + { + PlanNode subqueryNode = subqueryPlan.getRoot(); + Map correlation = extractCorrelation(subPlan, subqueryNode); + if (!correlationAllowed && !correlation.isEmpty()) { + throw notSupportedException(query, "Correlated subquery in given context"); + } + subqueryNode = replaceExpressionsWithSymbols(subqueryNode, correlation); + + return new PlanBuilder( + subPlan.copyTranslations(), + new LateralJoinNode( + idAllocator.getNextId(), + subPlan.getRoot(), + subqueryNode, + ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())), + LateralJoinNode.Type.INNER), + analysis.getParameters()); } private PlanBuilder appendExistsSubqueryApplyNodes(PlanBuilder builder, Set existsPredicates, boolean correlationAllowed) @@ -385,7 +399,7 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa private static boolean isAggregationWithEmptyGroupBy(PlanNode planNode) { return searchFrom(planNode) - .skipOnlyWhen(Predicates.isInstanceOfAny(ProjectNode.class)) + .skipOnlyWhen(MorePredicates.isInstanceOfAny(ProjectNode.class)) .where(AggregationNode.class::isInstance) .findFirst() .map(AggregationNode.class::cast) @@ -406,6 +420,7 @@ private SubqueryExpression uncoercedSubquery(SubqueryExpression subquery) private List coercionsFor(Expression expression) { return analysis.getCoercions().keySet().stream() + .map(NodeRef::getNode) .filter(coercionExpression -> coercionExpression.equals(expression)) .collect(toImmutableList()); } @@ -432,7 +447,7 @@ private PlanBuilder appendApplyNode( root, subqueryNode, subqueryAssignments, - ImmutableList.copyOf(DependencyExtractor.extractUnique(correlation.values()))), + ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values()))), analysis.getParameters()); } @@ -492,7 +507,7 @@ private Set extractOuterColumnReferences(PlanNode planNode) .collect(toImmutableSet()); } - private static Set extractColumnReferences(Expression expression, Set columnReferences) + private static Set extractColumnReferences(Expression expression, Set> columnReferences) { ImmutableSet.Builder expressionColumnReferences = ImmutableSet.builder(); new ColumnReferencesExtractor(columnReferences).process(expression, expressionColumnReferences); @@ -511,9 +526,9 @@ private PlanNode replaceExpressionsWithSymbols(PlanNode planNode, Map> { - private final Set columnReferences; + private final Set> columnReferences; - private ColumnReferencesExtractor(Set columnReferences) + private ColumnReferencesExtractor(Set> columnReferences) { this.columnReferences = requireNonNull(columnReferences, "columnReferences is null"); } @@ -521,7 +536,7 @@ private ColumnReferencesExtractor(Set columnReferences) @Override protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableSet.Builder builder) { - if (columnReferences.contains(node)) { + if (columnReferences.contains(NodeRef.of(node))) { builder.add(node); } else { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java index dfa3e441c4bc0..6f4a6554ec768 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java @@ -18,6 +18,7 @@ import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.primitives.Ints; @@ -56,6 +57,7 @@ public Symbol newHashSymbol() public Symbol newSymbol(String nameHint, Type type, String suffix) { requireNonNull(nameHint, "name is null"); + requireNonNull(type, "type is null"); // TODO: workaround for the fact that QualifiedName lowercases parts nameHint = nameHint.toLowerCase(ENGLISH); @@ -104,6 +106,9 @@ else if (expression instanceof FunctionCall) { else if (expression instanceof SymbolReference) { nameHint = ((SymbolReference) expression).getName(); } + else if (expression instanceof GroupingOperation) { + nameHint = "grouping"; + } return newSymbol(nameHint, type, suffix); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java similarity index 76% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java index 75f7b4ee15d77..114b30881d1f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DependencyExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java @@ -13,12 +13,14 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; @@ -28,11 +30,12 @@ import java.util.Set; import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressions; +import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressionsNonRecursive; import static java.util.Objects.requireNonNull; -public final class DependencyExtractor +public final class SymbolsExtractor { - private DependencyExtractor() {} + private SymbolsExtractor() {} public static Set extractUnique(PlanNode node) { @@ -42,6 +45,22 @@ public static Set extractUnique(PlanNode node) return uniqueSymbols.build(); } + public static Set extractUniqueNonRecursive(PlanNode node) + { + ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); + extractExpressionsNonRecursive(node).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression))); + + return uniqueSymbols.build(); + } + + public static Set extractUnique(PlanNode node, Lookup lookup) + { + ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); + extractExpressions(node, lookup).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression))); + + return uniqueSymbols.build(); + } + public static Set extractUnique(Expression expression) { return ImmutableSet.copyOf(extractAll(expression)); @@ -64,7 +83,7 @@ public static List extractAll(Expression expression) } // to extract qualified name with prefix - public static Set extractNames(Expression expression, Set columnReferences) + public static Set extractNames(Expression expression, Set> columnReferences) { ImmutableSet.Builder builder = ImmutableSet.builder(); new QualifiedNameBuilderVisitor(columnReferences).process(expression, builder); @@ -85,9 +104,9 @@ protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder< private static class QualifiedNameBuilderVisitor extends DefaultTraversalVisitor> { - private final Set columnReferences; + private final Set> columnReferences; - private QualifiedNameBuilderVisitor(Set columnReferences) + private QualifiedNameBuilderVisitor(Set> columnReferences) { this.columnReferences = requireNonNull(columnReferences, "columnReferences is null"); } @@ -95,7 +114,7 @@ private QualifiedNameBuilderVisitor(Set columnReferences) @Override protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableSet.Builder builder) { - if (columnReferences.contains(node)) { + if (columnReferences.contains(NodeRef.of(node))) { builder.add(DereferenceExpression.getQualifiedName(node)); } else { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java index 490b6d2a41e1d..4dcafe075ae5e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java @@ -25,7 +25,7 @@ import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ImmutableList; import java.util.HashMap; @@ -45,7 +45,7 @@ class TranslationMap // all expressions are rewritten in terms of fields declared by this relation plan private final RelationPlan rewriteBase; private final Analysis analysis; - private final IdentityLinkedHashMap lambdaDeclarationToSymbolMap; + private final Map, Symbol> lambdaDeclarationToSymbolMap; // current mappings of underlying field -> symbol for translating direct field references private final Symbol[] fieldSymbols; @@ -54,7 +54,7 @@ class TranslationMap private final Map expressionToSymbols = new HashMap<>(); private final Map expressionToExpressions = new HashMap<>(); - public TranslationMap(RelationPlan rewriteBase, Analysis analysis, IdentityLinkedHashMap lambdaDeclarationToSymbolMap) + public TranslationMap(RelationPlan rewriteBase, Analysis analysis, Map, Symbol> lambdaDeclarationToSymbolMap) { this.rewriteBase = requireNonNull(rewriteBase, "rewriteBase is null"); this.analysis = requireNonNull(analysis, "analysis is null"); @@ -73,7 +73,7 @@ public Analysis getAnalysis() return analysis; } - public IdentityLinkedHashMap getLambdaDeclarationToSymbolMap() + public Map, Symbol> getLambdaDeclarationToSymbolMap() { return lambdaDeclarationToSymbolMap; } @@ -240,7 +240,7 @@ public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTre { LambdaArgumentDeclaration referencedLambdaArgumentDeclaration = analysis.getLambdaArgumentReference(node); if (referencedLambdaArgumentDeclaration != null) { - Symbol symbol = lambdaDeclarationToSymbolMap.get(referencedLambdaArgumentDeclaration); + Symbol symbol = lambdaDeclarationToSymbolMap.get(NodeRef.of(referencedLambdaArgumentDeclaration)); return coerceIfNecessary(node, symbol.toSymbolReference()); } else { @@ -278,7 +278,8 @@ public Expression rewriteLambdaExpression(LambdaExpression node, Void context, E ImmutableList.Builder newArguments = ImmutableList.builder(); for (LambdaArgumentDeclaration argument : node.getArguments()) { - newArguments.add(new LambdaArgumentDeclaration(lambdaDeclarationToSymbolMap.get(argument).getName())); + Symbol symbol = lambdaDeclarationToSymbolMap.get(NodeRef.of(argument)); + newArguments.add(new LambdaArgumentDeclaration(symbol.getName())); } Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), null); return new LambdaExpression(newArguments.build(), rewrittenBody); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java index 681e61a893c70..79a8ca719e87d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/GroupReference.java @@ -16,6 +16,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.google.common.collect.ImmutableList; import java.util.List; @@ -46,6 +47,12 @@ public List getSources() return ImmutableList.of(); } + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitGroupReference(this, context); + } + @Override public List getOutputSymbols() { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java index 136626981915b..612836aee46de 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java @@ -24,9 +24,9 @@ import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import io.airlift.units.Duration; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -41,7 +41,7 @@ public class IterativeOptimizer implements PlanOptimizer { private final List legacyRules; - private final Set rules; + private final RuleStore ruleStore; private final StatsRecorder stats; public IterativeOptimizer(StatsRecorder stats, Set rules) @@ -52,10 +52,13 @@ public IterativeOptimizer(StatsRecorder stats, Set rules) public IterativeOptimizer(StatsRecorder stats, List legacyRules, Set newRules) { this.legacyRules = ImmutableList.copyOf(legacyRules); - this.rules = ImmutableSet.copyOf(newRules); + this.ruleStore = RuleStore.builder() + .register(newRules) + .build(); + this.stats = stats; - stats.registerAll(rules); + stats.registerAll(newRules); } @Override @@ -71,14 +74,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types } Memo memo = new Memo(idAllocator, plan); - - Lookup lookup = node -> { - if (node instanceof GroupReference) { - return memo.getNode(((GroupReference) node).getGroupId()); - } - - return node; - }; + Lookup lookup = Lookup.from(memo::resolve); Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); exploreGroup(memo.getRootGroup(), new Context(memo, lookup, idAllocator, symbolAllocator, System.nanoTime(), timeout.toMillis(), session)); @@ -119,8 +115,15 @@ private boolean exploreNode(int group, Context context) } done = true; - for (Rule rule : rules) { + Iterator possiblyMatchingRules = ruleStore.getCandidates(node).iterator(); + while (possiblyMatchingRules.hasNext()) { + Rule rule = possiblyMatchingRules.next(); Optional transformed; + + if (!rule.getPattern().matches(node)) { + continue; + } + long duration; try { long start = System.nanoTime(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java index 57b1e3e94a4f0..7305e501e05b1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java @@ -15,6 +15,10 @@ import com.facebook.presto.sql.planner.plan.PlanNode; +import java.util.function.Function; + +import static com.google.common.base.Verify.verify; + public interface Lookup { /** @@ -25,4 +29,27 @@ public interface Lookup * argument as is. */ PlanNode resolve(PlanNode node); + + /** + * A Lookup implementation that does not perform lookup. It satisfies contract + * by rejecting {@link GroupReference}-s. + */ + static Lookup noLookup() + { + return node -> { + verify(!(node instanceof GroupReference), "Unexpected GroupReference"); + return node; + }; + } + + static Lookup from(Function resolver) + { + return node -> { + if (node instanceof GroupReference) { + return resolver.apply((GroupReference) node); + } + + return node; + }; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java index 1c816e1ec2034..8f62ea8d4fde0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java @@ -18,11 +18,11 @@ import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import static com.facebook.presto.sql.planner.iterative.Plans.resolveGroupReferences; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -79,6 +79,11 @@ public PlanNode getNode(int group) return membership.get(group); } + public PlanNode resolve(GroupReference groupReference) + { + return getNode(groupReference.getGroupId()); + } + public PlanNode extract() { return extract(getNode(rootGroup)); @@ -86,15 +91,7 @@ public PlanNode extract() private PlanNode extract(PlanNode node) { - if (node instanceof GroupReference) { - return extract(membership.get(((GroupReference) node).getGroupId())); - } - - List children = node.getSources().stream() - .map(this::extract) - .collect(Collectors.toList()); - - return node.replaceChildren(children); + return resolveGroupReferences(node, Lookup.from(this::resolve)); } public PlanNode replace(int group, PlanNode node, String reason) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java new file mode 100644 index 0000000000000..a9ceb13168720 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Pattern.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative; + +import com.facebook.presto.sql.planner.plan.PlanNode; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public abstract class Pattern +{ + private static final Pattern ANY_NODE = new MatchNodeClass(PlanNode.class); + + private Pattern() {} + + public abstract boolean matches(PlanNode node); + + public static Pattern any() + { + return ANY_NODE; + } + + public static Pattern node(Class nodeClass) + { + return new MatchNodeClass(nodeClass); + } + + static class MatchNodeClass + extends Pattern + { + private final Class nodeClass; + + MatchNodeClass(Class nodeClass) + { + this.nodeClass = requireNonNull(nodeClass, "nodeClass is null"); + } + + Class getNodeClass() + { + return nodeClass; + } + + @Override + public boolean matches(PlanNode node) + { + return nodeClass.isInstance(node); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("nodeClass", nodeClass) + .toString(); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Plans.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Plans.java new file mode 100644 index 0000000000000..5db2530bf2729 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Plans.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative; + +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanVisitor; + +import java.util.List; +import java.util.stream.Collectors; + +import static java.util.Objects.requireNonNull; + +public class Plans +{ + public static PlanNode resolveGroupReferences(PlanNode node, Lookup lookup) + { + requireNonNull(node, "node is null"); + return node.accept(new ResolvingVisitor(lookup), null); + } + + private static class ResolvingVisitor + extends PlanVisitor + { + private final Lookup lookup; + + public ResolvingVisitor(Lookup lookup) + { + this.lookup = requireNonNull(lookup, "lookup is null"); + } + + @Override + protected PlanNode visitPlan(PlanNode node, Void context) + { + List children = node.getSources().stream() + .map(child -> child.accept(this, context)) + .collect(Collectors.toList()); + + return node.replaceChildren(children); + } + + @Override + public PlanNode visitGroupReference(GroupReference node, Void context) + { + return lookup.resolve(node).accept(this, context); + } + } + + private Plans() {} +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java index ae55eb54aa3d5..4d2776bf068bf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Rule.java @@ -22,5 +22,15 @@ public interface Rule { + /** + * Returns a pattern to which plan nodes this rule applies. + * Notice that rule may be still invoked for plan nodes which given pattern does not apply, + * then rule should return Optional.empty() in such case + */ + default Pattern getPattern() + { + return Pattern.any(); + } + Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java new file mode 100644 index 0000000000000..f190a5656092d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/RuleStore.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative; + +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Streams; + +import java.util.Iterator; +import java.util.Set; +import java.util.stream.Stream; + +public class RuleStore +{ + private final ListMultimap, Rule> rulesByClass; + + private RuleStore(ListMultimap, Rule> rulesByClass) + { + this.rulesByClass = ImmutableListMultimap.copyOf(rulesByClass); + } + + public Stream getCandidates(PlanNode planNode) + { + return Streams.stream(ancestors(planNode.getClass())) + .flatMap(clazz -> rulesByClass.get(clazz).stream()); + } + + private static Iterator> ancestors(Class planNodeClass) + { + return new AbstractIterator>() { + private Class current = planNodeClass; + + @Override + protected Class computeNext() + { + if (!PlanNode.class.isAssignableFrom(current)) { + return endOfData(); + } + + Class result = (Class) current; + current = current.getSuperclass(); + + return result; + } + }; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private final ImmutableListMultimap.Builder, Rule> rulesByClass = ImmutableListMultimap.builder(); + + public Builder register(Set newRules) + { + newRules.forEach(this::register); + return this; + } + + public Builder register(Rule newRule) + { + Pattern pattern = newRule.getPattern(); + if (pattern instanceof Pattern.MatchNodeClass) { + rulesByClass.put(((Pattern.MatchNodeClass) pattern).getNodeClass(), newRule); + } + else { + throw new IllegalArgumentException("Unexpected Pattern: " + pattern); + } + return this; + } + + public RuleStore build() + { + return new RuleStore(rulesByClass.build()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java index 066f86e3175d8..89fdb4e3c0d16 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -15,13 +15,14 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -66,6 +67,14 @@ public class AddIntermediateAggregations implements Rule { + private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + @Override public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) { @@ -101,7 +110,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato source = new AggregationNode( idAllocator.getNextId(), source, - inputsAsOutputs(aggregation.getAssignments()), + inputsAsOutputs(aggregation.getAggregations()), aggregation.getGroupingSets(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), @@ -143,7 +152,7 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI return new AggregationNode( idAllocator.getNextId(), gatheringExchange, - outputsAsInputs(aggregation.getAssignments()), + outputsAsInputs(aggregation.getAggregations()), aggregation.getGroupingSets(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), @@ -186,7 +195,7 @@ private static Map inputsAsOutputs(Map builder = ImmutableMap.builder(); for (Map.Entry entry : assignments.entrySet()) { // Should only have one input symbol - Symbol input = getOnlyElement(DependencyExtractor.extractAll(entry.getValue().getCall())); + Symbol input = getOnlyElement(SymbolsExtractor.extractAll(entry.getValue().getCall())); builder.put(input, entry.getValue()); } return builder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java new file mode 100644 index 0000000000000..82364de8c05c4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CreatePartialTopN.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TopNNode; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.plan.TopNNode.Step.FINAL; +import static com.facebook.presto.sql.planner.plan.TopNNode.Step.PARTIAL; +import static com.facebook.presto.sql.planner.plan.TopNNode.Step.SINGLE; + +public class CreatePartialTopN + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(TopNNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof TopNNode)) { + return Optional.empty(); + } + + TopNNode single = (TopNNode) node; + + if (!single.getStep().equals(SINGLE)) { + return Optional.empty(); + } + + PlanNode source = lookup.resolve(single.getSource()); + + TopNNode partial = new TopNNode( + idAllocator.getNextId(), + source, + single.getCount(), + single.getOrderBy(), + single.getOrderings(), + PARTIAL); + + return Optional.of(new TopNNode( + idAllocator.getNextId(), + partial, + single.getCount(), + single.getOrderBy(), + single.getOrderings(), + FINAL)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java new file mode 100644 index 0000000000000..744925d4f4adb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -0,0 +1,209 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.Set; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class EliminateCrossJoins + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(JoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof JoinNode)) { + return Optional.empty(); + } + + if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { + return Optional.empty(); + } + + JoinGraph joinGraph = JoinGraph.buildShallowFrom(node, lookup); + if (joinGraph.size() < 3) { + return Optional.empty(); + } + + List joinOrder = getJoinOrder(joinGraph); + if (isOriginalOrder(joinOrder)) { + return Optional.empty(); + } + + PlanNode replacement = buildJoinTree(node.getOutputSymbols(), joinGraph, joinOrder, idAllocator); + return Optional.of(replacement); + } + + public static boolean isOriginalOrder(List joinOrder) + { + for (int i = 0; i < joinOrder.size(); i++) { + if (joinOrder.get(i) != i) { + return false; + } + } + return true; + } + + /** + * Given JoinGraph determine the order of joins between graph nodes + * by traversing JoinGraph. Any graph traversal algorithm could be used + * here (like BFS or DFS), but we use PriorityQueue to preserve + * original JoinOrder as mush as it is possible. PriorityQueue returns + * next nodes to join in order of their occurrence in original Plan. + */ + public static List getJoinOrder(JoinGraph graph) + { + ImmutableList.Builder joinOrder = ImmutableList.builder(); + + Map priorities = new HashMap<>(); + for (int i = 0; i < graph.size(); i++) { + priorities.put(graph.getNode(i).getId(), i); + } + + PriorityQueue nodesToVisit = new PriorityQueue<>( + graph.size(), + (Comparator) (node1, node2) -> priorities.get(node1.getId()).compareTo(priorities.get(node2.getId()))); + Set visited = new HashSet<>(); + + nodesToVisit.add(graph.getNode(0)); + + while (!nodesToVisit.isEmpty()) { + PlanNode node = nodesToVisit.poll(); + if (!visited.contains(node)) { + visited.add(node); + joinOrder.add(node); + for (JoinGraph.Edge edge : graph.getEdges(node)) { + nodesToVisit.add(edge.getTargetNode()); + } + } + + if (nodesToVisit.isEmpty() && visited.size() < graph.size()) { + // disconnected graph, find new starting point + Optional firstNotVisitedNode = graph.getNodes().stream() + .filter(graphNode -> !visited.contains(graphNode)) + .findFirst(); + if (firstNotVisitedNode.isPresent()) { + nodesToVisit.add(firstNotVisitedNode.get()); + } + } + } + + checkState(visited.size() == graph.size()); + return joinOrder.build().stream() + .map(node -> priorities.get(node.getId())) + .collect(toImmutableList()); + } + + public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGraph graph, List joinOrder, PlanNodeIdAllocator idAllocator) + { + requireNonNull(expectedOutputSymbols, "expectedOutputSymbols is null"); + requireNonNull(idAllocator, "idAllocator is null"); + requireNonNull(graph, "graph is null"); + joinOrder = ImmutableList.copyOf(requireNonNull(joinOrder, "joinOrder is null")); + checkArgument(joinOrder.size() >= 2); + + PlanNode result = graph.getNode(joinOrder.get(0)); + Set alreadyJoinedNodes = new HashSet<>(); + alreadyJoinedNodes.add(result.getId()); + + for (int i = 1; i < joinOrder.size(); i++) { + PlanNode rightNode = graph.getNode(joinOrder.get(i)); + alreadyJoinedNodes.add(rightNode.getId()); + + ImmutableList.Builder criteria = ImmutableList.builder(); + + for (JoinGraph.Edge edge : graph.getEdges(rightNode)) { + PlanNode targetNode = edge.getTargetNode(); + if (alreadyJoinedNodes.contains(targetNode.getId())) { + criteria.add(new JoinNode.EquiJoinClause( + edge.getTargetSymbol(), + edge.getSourceSymbol())); + } + } + + result = new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.INNER, + result, + rightNode, + criteria.build(), + ImmutableList.builder() + .addAll(result.getOutputSymbols()) + .addAll(rightNode.getOutputSymbols()) + .build(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + List filters = graph.getFilters(); + + for (Expression filter : filters) { + result = new FilterNode( + idAllocator.getNextId(), + result, + filter); + } + + if (graph.getAssignments().isPresent()) { + result = new ProjectNode( + idAllocator.getNextId(), + result, + Assignments.copyOf(graph.getAssignments().get())); + } + + // If needed, introduce a projection to constrain the outputs to what was originally expected + // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) + return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(result); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java index dc84d809df378..b379303dfe779 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroLimit.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -28,13 +29,17 @@ public class EvaluateZeroLimit implements Rule { + private static final Pattern PATTERN = Pattern.node(LimitNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof LimitNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { LimitNode limit = (LimitNode) node; if (limit.getCount() != 0) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java index faafcd9462495..ecf5b95424fc8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EvaluateZeroSample.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -31,13 +32,17 @@ public class EvaluateZeroSample implements Rule { + private static final Pattern PATTERN = Pattern.node(SampleNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof SampleNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { SampleNode sample = (SampleNode) node; if (sample.getSampleRatio() != 0) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java index 5258d424fa9a3..9705352c30c15 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementBernoulliSampleAsFilter.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -49,13 +50,17 @@ public class ImplementBernoulliSampleAsFilter implements Rule { + private static final Pattern PATTERN = Pattern.node(SampleNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof SampleNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { SampleNode sample = (SampleNode) node; if (sample.getSampleType() != SampleNode.Type.BERNOULLI) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 3493afb5e4e0d..e1e54119393ca 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -18,8 +18,10 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -34,16 +36,13 @@ /** * Implements filtered aggregations by transforming plans of the following shape: - * *
  * - Aggregation
  *        F1(...) FILTER (WHERE C1(...)),
  *        F2(...) FILTER (WHERE C2(...))
  *     - X
  * 
- * * into - * *
  * - Aggregation
  *        F1(...) mask ($0)
@@ -58,47 +57,48 @@
 public class ImplementFilteredAggregations
         implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(AggregationNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof AggregationNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         AggregationNode aggregation = (AggregationNode) node;
 
         boolean hasFilters = aggregation.getAggregations()
-                .entrySet().stream()
-                .anyMatch(e -> e.getValue().getFilter().isPresent() &&
-                        !aggregation.getMasks().containsKey(e.getKey())); // can't handle filtered aggregations with DISTINCT (conservatively, if they have a mask)
+                .values().stream()
+                .anyMatch(e -> e.getCall().getFilter().isPresent() &&
+                        !e.getMask().isPresent()); // can't handle filtered aggregations with DISTINCT (conservatively, if they have a mask)
 
         if (!hasFilters) {
             return Optional.empty();
         }
 
         Assignments.Builder newAssignments = Assignments.builder();
-        ImmutableMap.Builder masks = ImmutableMap.builder()
-                .putAll(aggregation.getMasks());
-        ImmutableMap.Builder calls = ImmutableMap.builder();
+        ImmutableMap.Builder aggregations = ImmutableMap.builder();
 
-        for (Map.Entry entry : aggregation.getAggregations().entrySet()) {
+        for (Map.Entry entry : aggregation.getAggregations().entrySet()) {
             Symbol output = entry.getKey();
 
             // strip the filters
-            FunctionCall call = entry.getValue();
-            calls.put(output, new FunctionCall(
-                    call.getName(),
-                    call.getWindow(),
-                    Optional.empty(),
-                    call.isDistinct(),
-                    call.getArguments()));
+            FunctionCall call = entry.getValue().getCall();
+            Optional mask = entry.getValue().getMask();
 
             if (call.getFilter().isPresent()) {
-                Expression filter = entry.getValue().getFilter().get();
+                Expression filter = call.getFilter().get();
                 Symbol symbol = symbolAllocator.newSymbol(filter, BOOLEAN);
                 newAssignments.put(symbol, filter);
-                masks.put(output, symbol);
+                mask = Optional.of(symbol);
             }
+            aggregations.put(output, new Aggregation(
+                    new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.isDistinct(), call.getArguments()),
+                    entry.getValue().getSignature(),
+                    mask));
         }
 
         // identity projection for all existing inputs
@@ -111,9 +111,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
                                 idAllocator.getNextId(),
                                 aggregation.getSource(),
                                 newAssignments.build()),
-                        calls.build(),
-                        aggregation.getFunctions(),
-                        masks.build(),
+                        aggregations.build(),
                         aggregation.getGroupingSets(),
                         aggregation.getStep(),
                         aggregation.getHashSymbol(),
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java
index a2f1c7e92ab92..8df1175dbfe9e 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java
@@ -14,12 +14,13 @@
 package com.facebook.presto.sql.planner.iterative.rule;
 
 import com.facebook.presto.Session;
-import com.facebook.presto.sql.planner.DependencyExtractor;
 import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
 import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
 import com.facebook.presto.sql.planner.Symbol;
 import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.SymbolsExtractor;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.Assignments;
 import com.facebook.presto.sql.planner.plan.PlanNode;
@@ -47,13 +48,17 @@
 public class InlineProjections
         implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof ProjectNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         ProjectNode parent = (ProjectNode) node;
 
         PlanNode source = lookup.resolve(parent.getSource());
@@ -84,7 +89,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
                 .entrySet().stream()
                 .filter(entry -> targets.contains(entry.getKey()))
                 .map(Map.Entry::getValue)
-                .flatMap(entry -> DependencyExtractor.extractAll(entry).stream())
+                .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream())
                 .collect(toSet());
 
         Assignments.Builder childAssignments = Assignments.builder();
@@ -132,7 +137,7 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN
 
         Map dependencies = parent.getAssignments()
                 .getExpressions().stream()
-                .flatMap(expression -> DependencyExtractor.extractAll(expression).stream())
+                .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
                 .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
 
         // find references to simple constants
@@ -162,7 +167,7 @@ private Set extractTryArguments(Expression expression)
         return AstUtils.preOrder(expression)
                 .filter(TryExpression.class::isInstance)
                 .map(TryExpression.class::cast)
-                .flatMap(tryExpression -> DependencyExtractor.extractAll(tryExpression).stream())
+                .flatMap(tryExpression -> SymbolsExtractor.extractAll(tryExpression).stream())
                 .collect(toSet());
     }
 }
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java
new file mode 100644
index 0000000000000..38d3724c176d3
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeAdjacentWindows.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.WindowNode;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Optional;
+
+import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.dependsOn;
+
+public class MergeAdjacentWindows
+    implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(WindowNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        if (!(node instanceof WindowNode)) {
+            return Optional.empty();
+        }
+
+        WindowNode parent = (WindowNode) node;
+
+        PlanNode source = lookup.resolve(parent.getSource());
+        if (!(source instanceof WindowNode)) {
+            return Optional.empty();
+        }
+
+        WindowNode child = (WindowNode) source;
+
+        if (!child.getSpecification().equals(parent.getSpecification()) || dependsOn(parent, child)) {
+            return Optional.empty();
+        }
+
+        ImmutableMap.Builder functionsBuilder = ImmutableMap.builder();
+        functionsBuilder.putAll(parent.getWindowFunctions());
+        functionsBuilder.putAll(child.getWindowFunctions());
+
+        return Optional.of(new WindowNode(
+                parent.getId(),
+                child.getSource(),
+                parent.getSpecification(),
+                functionsBuilder.build(),
+                parent.getHashSymbol(),
+                parent.getPrePartitionedInputs(),
+                parent.getPreSortedOrderPrefix()));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java
index e00d6e1fba981..c70cbed39298d 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeFilters.java
@@ -17,6 +17,7 @@
 import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.FilterNode;
 import com.facebook.presto.sql.planner.plan.PlanNode;
@@ -28,13 +29,17 @@
 public class MergeFilters
     implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(FilterNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof FilterNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         FilterNode parent = (FilterNode) node;
 
         PlanNode source = lookup.resolve(parent.getSource());
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java
index 90867fd868630..616dd88cdad4b 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithDistinct.java
@@ -17,6 +17,7 @@
 import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.AggregationNode;
 import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
@@ -28,13 +29,17 @@
 public class MergeLimitWithDistinct
     implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(LimitNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof LimitNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         LimitNode parent = (LimitNode) node;
 
         PlanNode input = lookup.resolve(parent.getSource());
@@ -54,6 +59,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
                         child.getSource(),
                         parent.getCount(),
                         false,
+                        child.getGroupingKeys(),
                         child.getHashSymbol()));
     }
 
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java
index adf5f5c06a510..f6d466acc5644 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithSort.java
@@ -17,6 +17,7 @@
 import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.LimitNode;
 import com.facebook.presto.sql.planner.plan.PlanNode;
@@ -28,13 +29,17 @@
 public class MergeLimitWithSort
     implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(LimitNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof LimitNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         LimitNode parent = (LimitNode) node;
 
         PlanNode source = lookup.resolve(parent.getSource());
@@ -51,6 +56,6 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
                         parent.getCount(),
                         child.getOrderBy(),
                         child.getOrderings(),
-                        parent.isPartial()));
+                        parent.isPartial() ? TopNNode.Step.PARTIAL : TopNNode.Step.SINGLE));
     }
 }
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java
index fc369b50060af..cf91fbacfd17f 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimitWithTopN.java
@@ -17,6 +17,7 @@
 import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.LimitNode;
 import com.facebook.presto.sql.planner.plan.PlanNode;
@@ -27,13 +28,17 @@
 public class MergeLimitWithTopN
     implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(LimitNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof LimitNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         LimitNode parent = (LimitNode) node;
 
         PlanNode source = lookup.resolve(parent.getSource());
@@ -50,6 +55,6 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
                         Math.min(parent.getCount(), child.getCount()),
                         child.getOrderBy(),
                         child.getOrderings(),
-                        parent.isPartial()));
+                        parent.isPartial() ? TopNNode.Step.PARTIAL : TopNNode.Step.SINGLE));
     }
 }
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java
index 42f570d16986f..06e9ef9ccf65c 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeLimits.java
@@ -17,6 +17,7 @@
 import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.LimitNode;
 import com.facebook.presto.sql.planner.plan.PlanNode;
@@ -26,13 +27,17 @@
 public class MergeLimits
     implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(LimitNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof LimitNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         LimitNode parent = (LimitNode) node;
 
         PlanNode source = lookup.resolve(parent.getSource());
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java
new file mode 100644
index 0000000000000..a0aa799dbee89
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCrossJoinColumns.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.google.common.collect.ImmutableList;
+
+import java.util.Optional;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs;
+import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs;
+
+/**
+ * Cross joins don't support output symbol selection, so push the project-off through the node.
+ */
+public class PruneCrossJoinColumns
+        implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        ProjectNode parent = (ProjectNode) node;
+
+        PlanNode child = lookup.resolve(parent.getSource());
+        if (!(child instanceof JoinNode)) {
+            return Optional.empty();
+        }
+
+        JoinNode joinNode = (JoinNode) child;
+        if (!joinNode.isCrossJoin()) {
+            return Optional.empty();
+        }
+
+        return pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions())
+                .map(dependencies ->
+                        parent.replaceChildren(ImmutableList.of(
+                                restrictChildOutputs(idAllocator, joinNode, dependencies, dependencies).get())));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java
new file mode 100644
index 0000000000000..15fc5b8ee128f
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinChildrenColumns.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.SymbolsExtractor;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs;
+
+/**
+ * Non-Cross joins support output symbol selection, so make any project-off of child columns explicit in project nodes.
+ */
+public class PruneJoinChildrenColumns
+        implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(JoinNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        JoinNode joinNode = (JoinNode) node;
+        if (joinNode.isCrossJoin()) {
+            return Optional.empty();
+        }
+
+        Set globallyUsableInputs = ImmutableSet.builder()
+                .addAll(joinNode.getOutputSymbols())
+                .addAll(
+                        joinNode.getFilter()
+                                .map(SymbolsExtractor::extractUnique)
+                                .orElse(ImmutableSet.of()))
+                .build();
+
+        Set leftUsableInputs = ImmutableSet.builder()
+                .addAll(globallyUsableInputs)
+                .addAll(
+                        joinNode.getCriteria().stream()
+                                .map(JoinNode.EquiJoinClause::getLeft)
+                                .iterator())
+                .addAll(joinNode.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()))
+                .build();
+
+        Set rightUsableInputs = ImmutableSet.builder()
+                .addAll(globallyUsableInputs)
+                .addAll(
+                        joinNode.getCriteria().stream()
+                                .map(JoinNode.EquiJoinClause::getRight)
+                                .iterator())
+                .addAll(joinNode.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()))
+                .build();
+
+        return restrictChildOutputs(idAllocator, joinNode, leftUsableInputs, rightUsableInputs);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java
new file mode 100644
index 0000000000000..f2d4276ea28ca
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneJoinColumns.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.google.common.collect.ImmutableList;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+
+/**
+ * Non-cross joins support output symbol selection, so absorb any project-off into the node.
+ */
+public class PruneJoinColumns
+        implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        ProjectNode parent = (ProjectNode) node;
+
+        PlanNode child = lookup.resolve(parent.getSource());
+        if (!(child instanceof JoinNode)) {
+            return Optional.empty();
+        }
+
+        JoinNode joinNode = (JoinNode) child;
+        if (joinNode.isCrossJoin()) {
+            return Optional.empty();
+        }
+
+        Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
+        if (!dependencies.isPresent()) {
+            return Optional.empty();
+        }
+
+        return Optional.of(
+                parent.replaceChildren(ImmutableList.of(
+                        new JoinNode(
+                                joinNode.getId(),
+                                joinNode.getType(),
+                                joinNode.getLeft(),
+                                joinNode.getRight(),
+                                joinNode.getCriteria(),
+                                joinNode.getOutputSymbols().stream()
+                                        .filter(dependencies.get()::contains)
+                                        .collect(toImmutableList()),
+                                joinNode.getFilter(),
+                                joinNode.getLeftHashSymbol(),
+                                joinNode.getRightHashSymbol(),
+                                joinNode.getDistributionType()))));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java
new file mode 100644
index 0000000000000..9f4766e7a5985
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMarkDistinctColumns.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Streams;
+
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Stream;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs;
+import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+
+public class PruneMarkDistinctColumns
+        implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        ProjectNode parent = (ProjectNode) node;
+
+        PlanNode child = lookup.resolve(parent.getSource());
+        if (!(child instanceof MarkDistinctNode)) {
+            return Optional.empty();
+        }
+
+        MarkDistinctNode markDistinctNode = (MarkDistinctNode) child;
+
+        Optional> prunedOutputs = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
+        if (!prunedOutputs.isPresent()) {
+            return Optional.empty();
+        }
+
+        if (!prunedOutputs.get().contains(markDistinctNode.getMarkerSymbol())) {
+            return Optional.of(
+                    node.replaceChildren(ImmutableList.of(markDistinctNode.getSource())));
+        }
+
+        Set requiredInputs = Streams.concat(
+                prunedOutputs.get().stream()
+                        .filter(symbol -> !symbol.equals(markDistinctNode.getMarkerSymbol())),
+                markDistinctNode.getDistinctSymbols().stream(),
+                markDistinctNode.getHashSymbol().map(Stream::of).orElse(Stream.empty()))
+                .collect(toImmutableSet());
+
+        return restrictOutputs(idAllocator, markDistinctNode.getSource(), requiredInputs)
+                .map(prunedMarkDistinctSource ->
+                        parent.replaceChildren(ImmutableList.of(
+                                markDistinctNode.replaceChildren(ImmutableList.of(
+                                        prunedMarkDistinctSource)))));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java
new file mode 100644
index 0000000000000..16f1cb34dcfda
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinColumns.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Streams;
+
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Stream;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs;
+import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+
+public class PruneSemiJoinColumns
+        implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        ProjectNode parent = (ProjectNode) node;
+
+        PlanNode child = lookup.resolve(parent.getSource());
+        if (!(child instanceof SemiJoinNode)) {
+            return Optional.empty();
+        }
+
+        SemiJoinNode semiJoinNode = (SemiJoinNode) child;
+
+        Optional> prunedOutputs = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
+        if (!prunedOutputs.isPresent()) {
+            return Optional.empty();
+        }
+
+        if (!prunedOutputs.get().contains(semiJoinNode.getSemiJoinOutput())) {
+            return Optional.of(
+                    parent.replaceChildren(ImmutableList.of(semiJoinNode.getSource())));
+        }
+
+        Set requiredSourceInputs = Streams.concat(
+                prunedOutputs.get().stream()
+                        .filter(symbol -> !symbol.equals(semiJoinNode.getSemiJoinOutput())),
+                Stream.of(semiJoinNode.getSourceJoinSymbol()),
+                semiJoinNode.getSourceHashSymbol().map(Stream::of).orElse(Stream.empty()))
+                .collect(toImmutableSet());
+
+        return restrictOutputs(idAllocator, semiJoinNode.getSource(), requiredSourceInputs)
+                .map(newSource ->
+                        parent.replaceChildren(ImmutableList.of(
+                                semiJoinNode.replaceChildren(ImmutableList.of(
+                                        newSource, semiJoinNode.getFilteringSource())))));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java
new file mode 100644
index 0000000000000..071a5965756bb
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneSemiJoinFilteringSourceColumns.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Streams;
+
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Stream;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+
+public class PruneSemiJoinFilteringSourceColumns
+        implements Rule
+{
+    private static final Pattern PATTERN = Pattern.node(SemiJoinNode.class);
+
+    @Override
+    public Pattern getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
+        SemiJoinNode semiJoinNode = (SemiJoinNode) node;
+
+        Set requiredFilteringSourceInputs = Streams.concat(
+                Stream.of(semiJoinNode.getFilteringSourceJoinSymbol()),
+                semiJoinNode.getFilteringSourceHashSymbol().map(Stream::of).orElse(Stream.empty()))
+                .collect(toImmutableSet());
+
+        return restrictOutputs(idAllocator, semiJoinNode.getFilteringSource(), requiredFilteringSourceInputs)
+                .map(newFilteringSource ->
+                        semiJoinNode.replaceChildren(ImmutableList.of(
+                                semiJoinNode.getSource(), newFilteringSource)));
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java
index 206e74eb8eb9c..3e654efc9a942 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableScanColumns.java
@@ -18,6 +18,7 @@
 import com.facebook.presto.sql.planner.Symbol;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.PlanNode;
 import com.facebook.presto.sql.planner.plan.ProjectNode;
@@ -25,21 +26,27 @@
 
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs;
+import static com.google.common.collect.ImmutableList.toImmutableList;
 
 public class PruneTableScanColumns
         implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof ProjectNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         ProjectNode parent = (ProjectNode) node;
 
         PlanNode source = lookup.resolve(parent.getSource());
@@ -49,12 +56,15 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
 
         TableScanNode child = (TableScanNode) source;
 
-        Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
+        Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
         if (!dependencies.isPresent()) {
             return Optional.empty();
         }
 
-        List newOutputs = dependencies.get();
+        List newOutputs = child.getOutputSymbols().stream()
+                .filter(dependencies.get()::contains)
+                .collect(toImmutableList());
+
         return Optional.of(
                 new ProjectNode(
                         parent.getId(),
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java
index 002afd4265aeb..ac85cfe0b4ff6 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneValuesColumns.java
@@ -18,6 +18,7 @@
 import com.facebook.presto.sql.planner.Symbol;
 import com.facebook.presto.sql.planner.SymbolAllocator;
 import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
 import com.facebook.presto.sql.planner.iterative.Rule;
 import com.facebook.presto.sql.planner.plan.PlanNode;
 import com.facebook.presto.sql.planner.plan.ProjectNode;
@@ -28,20 +29,26 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs;
+import static com.google.common.collect.ImmutableList.toImmutableList;
 
 public class PruneValuesColumns
         implements Rule
 {
+    private static final Pattern PATTERN = Pattern.node(ProjectNode.class);
+
     @Override
-    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    public Pattern getPattern()
     {
-        if (!(node instanceof ProjectNode)) {
-            return Optional.empty();
-        }
+        return PATTERN;
+    }
 
+    @Override
+    public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
+    {
         ProjectNode parent = (ProjectNode) node;
 
         PlanNode child = lookup.resolve(parent.getSource());
@@ -51,12 +58,14 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato
 
         ValuesNode values = (ValuesNode) child;
 
-        Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
+        Optional> dependencies = pruneInputs(child.getOutputSymbols(), parent.getAssignments().getExpressions());
         if (!dependencies.isPresent()) {
             return Optional.empty();
         }
 
-        List newOutputs = dependencies.get();
+        List newOutputs = child.getOutputSymbols().stream()
+                .filter(dependencies.get()::contains)
+                .collect(toImmutableList());
 
         // for each output of project, the corresponding column in the values node
         int[] mapping = new int[newOutputs.size()];
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java
new file mode 100644
index 0000000000000..073ebbb74f048
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
+import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.SymbolAllocator;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.iterative.Pattern;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.AggregationNode;
+import com.facebook.presto.sql.planner.plan.Assignments;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.ValuesNode;
+import com.facebook.presto.sql.tree.CoalesceExpression;
+import com.facebook.presto.sql.tree.Expression;
+import com.facebook.presto.sql.tree.FunctionCall;
+import com.facebook.presto.sql.tree.NullLiteral;
+import com.facebook.presto.sql.tree.SymbolReference;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.shouldPushAggregationThroughJoin;
+import static com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+
+/**
+ * This optimizer pushes aggregations below outer joins when: the aggregation
+ * is on top of the outer join, it groups by all columns in the outer table, and
+ * the outer rows are guaranteed to be distinct.
+ * 

+ * When the aggregation is pushed down, we still need to perform aggregations + * on the null values that come out of the absent values in an outer + * join. We add a cross join with a row of aggregations on null literals, + * and coalesce the aggregation that results from the left outer join with + * the result of the aggregation over nulls. + *

+ * Example: + *

+ * - Filter ("nationkey" > "avg")
+ *  - Aggregate(Group by: all columns from the left table, aggregation:
+ *    avg("n2.nationkey"))
+ *      - LeftJoin("regionkey" = "regionkey")
+ *          - AssignUniqueId (nation)
+ *              - Tablescan (nation)
+ *          - Tablescan (nation)
+ * 
+ *

+ * Is rewritten to: + *
+ * - Filter ("nationkey" > "avg")
+ *  - project(regionkey, coalesce("avg", "avg_over_null")
+ *      - CrossJoin
+ *          - LeftJoin("regionkey" = "regionkey")
+ *              - AssignUniqueId (nation)
+ *                  - Tablescan (nation)
+ *              - Aggregate(Group by: regionkey, aggregation:
+ *                avg(nationkey))
+ *                  - Tablescan (nation)
+ *          - Aggregate
+ *            avg(null_literal)
+ *              - Values (null_literal)
+ * 
+ */ +public class PushAggregationThroughOuterJoin + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!shouldPushAggregationThroughJoin(session)) { + return Optional.empty(); + } + + if (!(node instanceof AggregationNode)) { + return Optional.empty(); + } + + AggregationNode aggregation = (AggregationNode) node; + PlanNode source = lookup.resolve(aggregation.getSource()); + if (!(source instanceof JoinNode)) { + return Optional.empty(); + } + JoinNode join = (JoinNode) source; + if (join.getFilter().isPresent() + || !(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT) + || !groupsOnAllOuterTableColumns(aggregation, lookup.resolve(getOuterTable(join))) + || !isDistinct(lookup.resolve(getOuterTable(join)), lookup::resolve)) { + return Optional.empty(); + } + + List groupingKeys = join.getCriteria().stream() + .map(join.getType() == JoinNode.Type.RIGHT ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight) + .collect(toImmutableList()); + AggregationNode rewrittenAggregation = new AggregationNode( + node.getId(), + getInnerTable(join), + aggregation.getAggregations(), + ImmutableList.of(groupingKeys), + aggregation.getStep(), + aggregation.getHashSymbol(), + aggregation.getGroupIdSymbol()); + + JoinNode rewrittenJoin; + if (join.getType() == JoinNode.Type.LEFT) { + rewrittenJoin = new JoinNode( + join.getId(), + join.getType(), + join.getLeft(), + rewrittenAggregation, + join.getCriteria(), + ImmutableList.builder() + .addAll(join.getLeft().getOutputSymbols()) + .addAll(rewrittenAggregation.getAggregations().keySet()) + .build(), + join.getFilter(), + join.getLeftHashSymbol(), + join.getRightHashSymbol(), + join.getDistributionType()); + } + else { + rewrittenJoin = new JoinNode( + join.getId(), + join.getType(), + rewrittenAggregation, + join.getRight(), + join.getCriteria(), + ImmutableList.builder() + .addAll(rewrittenAggregation.getAggregations().keySet()) + .addAll(join.getRight().getOutputSymbols()) + .build(), + join.getFilter(), + join.getLeftHashSymbol(), + join.getRightHashSymbol(), + join.getDistributionType()); + } + + return Optional.of(coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, symbolAllocator, idAllocator, lookup)); + } + + private static PlanNode getInnerTable(JoinNode join) + { + checkState(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN"); + PlanNode innerNode; + if (join.getType().equals(JoinNode.Type.LEFT)) { + innerNode = join.getRight(); + } + else { + innerNode = join.getLeft(); + } + return innerNode; + } + + private static PlanNode getOuterTable(JoinNode join) + { + checkState(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN"); + PlanNode outerNode; + if (join.getType().equals(JoinNode.Type.LEFT)) { + outerNode = join.getLeft(); + } + else { + outerNode = join.getRight(); + } + return outerNode; + } + + private static boolean groupsOnAllOuterTableColumns(AggregationNode node, PlanNode outerTable) + { + return new HashSet<>(node.getGroupingKeys()).equals(new HashSet<>(outerTable.getOutputSymbols())); + } + + // When the aggregation is done after the join, there will be a null value that gets aggregated over + // where rows did not exist in the inner table. For some aggregate functions, such as count, the result + // of an aggregation over a single null row is one or zero rather than null. In order to ensure correct results, + // we add a coalesce function with the output of the new outer join and the agggregation performed over a single + // null row. + private PlanNode coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) + { + // Create an aggregation node over a row of nulls. + MappedAggregationInfo aggregationOverNullInfo = createAggregationOverNull(aggregationNode, symbolAllocator, idAllocator, lookup); + AggregationNode aggregationOverNull = aggregationOverNullInfo.getAggregation(); + Map sourceAggregationToOverNullMapping = aggregationOverNullInfo.getSymbolMapping(); + + // Do a cross join with the aggregation over null + JoinNode crossJoin = new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.INNER, + outerJoin, + aggregationOverNull, + ImmutableList.of(), + ImmutableList.builder() + .addAll(outerJoin.getOutputSymbols()) + .addAll(aggregationOverNull.getOutputSymbols()) + .build(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + + // Add coalesce expressions for all aggregation functions + Assignments.Builder assignmentsBuilder = Assignments.builder(); + for (Symbol symbol : outerJoin.getOutputSymbols()) { + if (aggregationNode.getAggregations().containsKey(symbol)) { + assignmentsBuilder.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); + } + else { + assignmentsBuilder.put(symbol, symbol.toSymbolReference()); + } + } + return new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build()); + } + + private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) + { + // Create a values node that consists of a single row of nulls. + // Map the output symbols from the referenceAggregation's source + // to symbol references for the new values node. + NullLiteral nullLiteral = new NullLiteral(); + ImmutableList.Builder nullSymbols = ImmutableList.builder(); + ImmutableList.Builder nullLiterals = ImmutableList.builder(); + ImmutableMap.Builder sourcesSymbolMappingBuilder = ImmutableMap.builder(); + for (Symbol sourceSymbol : lookup.resolve(referenceAggregation.getSource()).getOutputSymbols()) { + nullLiterals.add(nullLiteral); + Symbol nullSymbol = symbolAllocator.newSymbol(nullLiteral, symbolAllocator.getTypes().get(sourceSymbol)); + nullSymbols.add(nullSymbol); + sourcesSymbolMappingBuilder.put(sourceSymbol, nullSymbol.toSymbolReference()); + } + ValuesNode nullRow = new ValuesNode( + idAllocator.getNextId(), + nullSymbols.build(), + ImmutableList.of(nullLiterals.build())); + Map sourcesSymbolMapping = sourcesSymbolMappingBuilder.build(); + + // For each aggregation function in the reference node, create a corresponding aggregation function + // that points to the nullRow. Map the symbols from the aggregations in referenceAggregation to the + // symbols in these new aggregations. + ImmutableMap.Builder aggregationsSymbolMappingBuilder = ImmutableMap.builder(); + ImmutableMap.Builder aggregationsOverNullBuilder = ImmutableMap.builder(); + for (Map.Entry entry : referenceAggregation.getAggregations().entrySet()) { + Symbol aggregationSymbol = entry.getKey(); + AggregationNode.Aggregation aggregation = entry.getValue(); + AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation( + (FunctionCall) new ExpressionSymbolInliner(sourcesSymbolMapping).rewrite(aggregation.getCall()), + aggregation.getSignature(), + aggregation.getMask().map(x -> Symbol.from(sourcesSymbolMapping.get(x)))); + Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getCall(), symbolAllocator.getTypes().get(aggregationSymbol)); + aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation); + aggregationsSymbolMappingBuilder.put(aggregationSymbol, overNullSymbol); + } + Map aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.build(); + + // create an aggregation node whose source is the null row. + AggregationNode aggregationOverNullRow = new AggregationNode( + idAllocator.getNextId(), + nullRow, + aggregationsOverNullBuilder.build(), + ImmutableList.of(ImmutableList.of()), + AggregationNode.Step.SINGLE, + Optional.empty(), + Optional.empty() + ); + return new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping); + } + + private static class MappedAggregationInfo + { + private final AggregationNode aggregationNode; + private final Map symbolMapping; + + public MappedAggregationInfo(AggregationNode aggregationNode, Map symbolMapping) + { + this.aggregationNode = aggregationNode; + this.symbolMapping = symbolMapping; + } + + public Map getSymbolMapping() + { + return symbolMapping; + } + + public AggregationNode getAggregation() + { + return aggregationNode; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java index e36accccc334d..9f2120ac06939 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughMarkDistinct.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; @@ -29,13 +30,17 @@ public class PushLimitThroughMarkDistinct implements Rule { + private static final Pattern PATTERN = Pattern.node(LimitNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof LimitNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { LimitNode parent = (LimitNode) node; PlanNode child = lookup.resolve(parent.getSource()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java index 0ff93ac54e80a..562413ee3c2a8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughProject.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,13 +30,17 @@ public class PushLimitThroughProject implements Rule { + private static final Pattern PATTERN = Pattern.node(LimitNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof LimitNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { LimitNode parent = (LimitNode) node; PlanNode child = lookup.resolve(parent.getSource()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java index 22676f9606035..8953b38b0da05 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushLimitThroughSemiJoin.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -29,13 +30,17 @@ public class PushLimitThroughSemiJoin implements Rule { + private static final Pattern PATTERN = Pattern.node(LimitNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof LimitNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { LimitNode parent = (LimitNode) node; PlanNode child = lookup.resolve(parent.getSource()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java new file mode 100644 index 0000000000000..543846469a342 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.ExpressionSymbolInliner; +import com.facebook.presto.sql.planner.PartitioningScheme; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; + +/** + * Transforms: + * + *
+ *  Project(x = e1, y = e2)
+ *    Exchange()
+ *      Source(a, b, c)
+ *  
+ * + * to: + * + *
+ *  Exchange()
+ *    Project(x = e1, y = e2)
+ *      Source(a, b, c)
+ *  
+ * + * Or if Exchange needs symbols from Source for partitioning or as hash symbol to: + * + *
+ *  Project(x, y)
+ *    Exchange()
+ *      Project(x = e1, y = e2, a)
+ *        Source(a, b, c)
+ *  
+ * + * + * To avoid looping this optimizer will not be fired if upper Project contains just symbol references. + */ +public class PushProjectionThroughExchange + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof ProjectNode)) { + return Optional.empty(); + } + + ProjectNode project = (ProjectNode) node; + + PlanNode child = lookup.resolve(project.getSource()); + if (!(child instanceof ExchangeNode)) { + return Optional.empty(); + } + + if (isSymbolToSymbolProjection(project)) { + return Optional.empty(); + } + + ExchangeNode exchange = (ExchangeNode) child; + + ImmutableList.Builder newSourceBuilder = ImmutableList.builder(); + ImmutableList.Builder> inputsBuilder = ImmutableList.builder(); + for (int i = 0; i < exchange.getSources().size(); i++) { + Map outputToInputMap = extractExchangeOutputToInput(exchange, i); + + Assignments.Builder projections = Assignments.builder(); + ImmutableList.Builder inputs = ImmutableList.builder(); + + // Need to retain the partition keys for the exchange + exchange.getPartitioningScheme().getPartitioning().getColumns().stream() + .map(outputToInputMap::get) + .forEach(nameReference -> { + Symbol symbol = Symbol.from(nameReference); + projections.put(symbol, nameReference); + inputs.add(symbol); + }); + + if (exchange.getPartitioningScheme().getHashColumn().isPresent()) { + // Need to retain the hash symbol for the exchange + projections.put(exchange.getPartitioningScheme().getHashColumn().get(), exchange.getPartitioningScheme().getHashColumn().get().toSymbolReference()); + inputs.add(exchange.getPartitioningScheme().getHashColumn().get()); + } + for (Map.Entry projection : project.getAssignments().entrySet()) { + Expression translatedExpression = translateExpression(projection.getValue(), outputToInputMap); + Type type = symbolAllocator.getTypes().get(projection.getKey()); + Symbol symbol = symbolAllocator.newSymbol(translatedExpression, type); + projections.put(symbol, translatedExpression); + inputs.add(symbol); + } + newSourceBuilder.add(new ProjectNode(idAllocator.getNextId(), exchange.getSources().get(i), projections.build())); + inputsBuilder.add(inputs.build()); + } + + // Construct the output symbols in the same order as the sources + ImmutableList.Builder outputBuilder = ImmutableList.builder(); + exchange.getPartitioningScheme().getPartitioning().getColumns() + .forEach(outputBuilder::add); + if (exchange.getPartitioningScheme().getHashColumn().isPresent()) { + outputBuilder.add(exchange.getPartitioningScheme().getHashColumn().get()); + } + for (Map.Entry projection : project.getAssignments().entrySet()) { + outputBuilder.add(projection.getKey()); + } + + // outputBuilder contains all partition and hash symbols so simply swap the output layout + PartitioningScheme partitioningScheme = new PartitioningScheme( + exchange.getPartitioningScheme().getPartitioning(), + outputBuilder.build(), + exchange.getPartitioningScheme().getHashColumn(), + exchange.getPartitioningScheme().isReplicateNullsAndAny(), + exchange.getPartitioningScheme().getBucketToPartition()); + + PlanNode result = new ExchangeNode( + exchange.getId(), + exchange.getType(), + exchange.getScope(), + partitioningScheme, + newSourceBuilder.build(), + inputsBuilder.build()); + + // we need to strip unnecessary symbols (hash, partitioning columns). + return Optional.of(restrictOutputs(idAllocator, result, ImmutableSet.copyOf(project.getOutputSymbols())).orElse(result)); + } + + private boolean isSymbolToSymbolProjection(ProjectNode project) + { + return project.getAssignments().getExpressions().stream().allMatch(e -> e instanceof SymbolReference); + } + + private static Map extractExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) + { + Map outputToInputMap = new HashMap<>(); + for (int i = 0; i < exchange.getOutputSymbols().size(); i++) { + outputToInputMap.put(exchange.getOutputSymbols().get(i), exchange.getInputs().get(sourceIndex).get(i).toSymbolReference()); + } + return outputToInputMap; + } + + private static Expression translateExpression(Expression inputExpression, Map symbolMapping) + { + return new ExpressionSymbolInliner(symbolMapping::get).rewrite(inputExpression); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java new file mode 100644 index 0000000000000..e203eff472a1c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.ExpressionSymbolInliner; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.UnionNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class PushProjectionThroughUnion + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof ProjectNode)) { + return Optional.empty(); + } + + ProjectNode parent = (ProjectNode) node; + + PlanNode child = lookup.resolve(parent.getSource()); + if (!(child instanceof UnionNode)) { + return Optional.empty(); + } + + UnionNode source = (UnionNode) child; + + // OutputLayout of the resultant Union, will be same as the layout of the Project + List outputLayout = node.getOutputSymbols(); + + // Mapping from the output symbol to ordered list of symbols from each of the sources + ImmutableListMultimap.Builder mappings = ImmutableListMultimap.builder(); + + // sources for the resultant UnionNode + ImmutableList.Builder outputSources = ImmutableList.builder(); + + for (int i = 0; i < child.getSources().size(); i++) { + Map outputToInput = source.sourceSymbolMap(i); // Map: output of union -> input of this source to the union + Assignments.Builder assignments = Assignments.builder(); // assignments for the new ProjectNode + + // mapping from current ProjectNode to new ProjectNode, used to identify the output layout + Map projectSymbolMapping = new HashMap<>(); + + // Translate the assignments in the ProjectNode using symbols of the source of the UnionNode + for (Map.Entry entry : parent.getAssignments().entrySet()) { + Expression translatedExpression = translateExpression(entry.getValue(), outputToInput); + Type type = symbolAllocator.getTypes().get(entry.getKey()); + Symbol symbol = symbolAllocator.newSymbol(translatedExpression, type); + assignments.put(symbol, translatedExpression); + projectSymbolMapping.put(entry.getKey(), symbol); + } + outputSources.add(new ProjectNode(idAllocator.getNextId(), source.getSources().get(i), assignments.build())); + outputLayout.forEach(symbol -> mappings.put(symbol, projectSymbolMapping.get(symbol))); + } + + return Optional.of(new UnionNode(node.getId(), outputSources.build(), mappings.build(), ImmutableList.copyOf(mappings.build().keySet()))); + } + + private static Expression translateExpression(Expression inputExpression, Map symbolMapping) + { + return new ExpressionSymbolInliner(symbolMapping::get).rewrite(inputExpression); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java new file mode 100644 index 0000000000000..280cde9c19e92 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushTopNThroughUnion.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.SymbolMapper; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TopNNode; +import com.facebook.presto.sql.planner.plan.UnionNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.plan.TopNNode.Step.PARTIAL; +import static com.google.common.collect.Iterables.getLast; +import static com.google.common.collect.Sets.intersection; + +public class PushTopNThroughUnion + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(TopNNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof TopNNode)) { + return Optional.empty(); + } + + TopNNode topNNode = (TopNNode) node; + + if (!topNNode.getStep().equals(PARTIAL)) { + return Optional.empty(); + } + + PlanNode child = lookup.resolve(topNNode.getSource()); + if (!(child instanceof UnionNode)) { + return Optional.empty(); + } + UnionNode unionNode = (UnionNode) child; + + ImmutableList.Builder sources = ImmutableList.builder(); + + for (PlanNode source : unionNode.getSources()) { + SymbolMapper.Builder symbolMapper = SymbolMapper.builder(); + Set sourceOutputSymbols = ImmutableSet.copyOf(source.getOutputSymbols()); + + for (Symbol unionOutput : unionNode.getOutputSymbols()) { + Set inputSymbols = ImmutableSet.copyOf(unionNode.getSymbolMapping().get(unionOutput)); + Symbol unionInput = getLast(intersection(inputSymbols, sourceOutputSymbols)); + symbolMapper.put(unionOutput, unionInput); + } + sources.add(symbolMapper.build().map(topNNode, source, idAllocator.getNextId())); + } + + return Optional.of(new UnionNode( + unionNode.getId(), + sources.build(), + unionNode.getSymbolMapping(), + unionNode.getOutputSymbols())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java index 5ca4ec64936fb..6f6527da4fa6f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveEmptyDelete.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -51,14 +52,19 @@ public class RemoveEmptyDelete implements Rule { + private static final Pattern PATTERN = Pattern.node(TableFinishNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + @Override public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) { // TODO split into multiple rules (https://github.com/prestodb/presto/issues/7292) - if (!(node instanceof TableFinishNode)) { - return Optional.empty(); - } TableFinishNode finish = (TableFinishNode) node; PlanNode finishSource = lookup.resolve(finish.getSource()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java index 41fe9fc76bf4c..bc7d079f1cc1a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveFullSample.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SampleNode; @@ -29,13 +30,17 @@ public class RemoveFullSample implements Rule { + private static final Pattern PATTERN = Pattern.node(SampleNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof SampleNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { SampleNode sample = (SampleNode) node; //noinspection FloatingPointEquality diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java index 0aa882ab19ac9..5b0605650aa04 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantIdentityProjections.java @@ -17,6 +17,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -30,13 +31,17 @@ public class RemoveRedundantIdentityProjections implements Rule { + private static final Pattern PATTERN = Pattern.node(ProjectNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof ProjectNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { ProjectNode project = (ProjectNode) node; if (!project.isIdentity()) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java new file mode 100644 index 0000000000000..d2e950f67017a --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; +import static java.util.Optional.empty; + +public class RemoveUnreferencedScalarLateralNodes + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + LateralJoinNode lateralJoinNode = (LateralJoinNode) node; + PlanNode input = lateralJoinNode.getInput(); + PlanNode subquery = lateralJoinNode.getSubquery(); + + if (isUnreferencedScalar(input, lookup)) { + return Optional.of(subquery); + } + + if (isUnreferencedScalar(subquery, lookup)) { + return Optional.of(input); + } + + return empty(); + } + + private boolean isUnreferencedScalar(PlanNode input, Lookup lookup) + { + return input.getOutputSymbols().isEmpty() && isScalar(input, lookup); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 14c1ca785e74a..46d94c84adea2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -44,13 +45,17 @@ public class SimplifyCountOverConstant implements Rule { + private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof AggregationNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { AggregationNode parent = (AggregationNode) node; PlanNode input = lookup.resolve(parent.getSource()); @@ -61,15 +66,15 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato ProjectNode child = (ProjectNode) input; boolean changed = false; - Map assignments = new LinkedHashMap<>(parent.getAssignments()); + Map aggregations = new LinkedHashMap<>(parent.getAggregations()); - for (Entry entry : parent.getAssignments().entrySet()) { + for (Entry entry : parent.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); AggregationNode.Aggregation aggregation = entry.getValue(); if (isCountOverConstant(aggregation, child.getAssignments())) { changed = true; - assignments.put(symbol, new AggregationNode.Aggregation( + aggregations.put(symbol, new AggregationNode.Aggregation( new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT)), aggregation.getMask())); @@ -83,7 +88,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.of(new AggregationNode( node.getId(), child, - assignments, + aggregations, parent.getGroupingSets(), parent.getStep(), parent.getHashSymbol(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java index 1612921f3c406..e841e849931a5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleMarkDistinctToGroupBy.java @@ -18,8 +18,10 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.tree.FunctionCall; @@ -27,13 +29,16 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; +import static com.google.common.collect.ImmutableList.toImmutableList; /** * Converts Single Distinct Aggregation into GroupBy @@ -48,13 +53,17 @@ public class SingleMarkDistinctToGroupBy implements Rule { + private static final Pattern PATTERN = Pattern.node(AggregationNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof AggregationNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { AggregationNode parent = (AggregationNode) node; PlanNode source = lookup.resolve(parent.getSource()); @@ -65,6 +74,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato MarkDistinctNode child = (MarkDistinctNode) source; boolean hasFilters = parent.getAggregations().values().stream() + .map(Aggregation::getCall) .map(FunctionCall::getFilter) .anyMatch(Optional::isPresent); @@ -75,12 +85,21 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato // optimize if and only if // all aggregation functions have a single common distinct mask symbol // AND all aggregation functions have mask - Set masks = ImmutableSet.copyOf(parent.getMasks().values()); - if (masks.size() != 1 || parent.getMasks().size() != parent.getAggregations().size()) { + Collection aggregations = parent.getAggregations().values(); + + List masks = aggregations.stream() + .map(Aggregation::getMask) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + + Set uniqueMasks = ImmutableSet.copyOf(masks); + + if (uniqueMasks.size() != 1 || masks.size() != aggregations.size()) { return Optional.empty(); } - Symbol mask = Iterables.getOnlyElement(masks); + Symbol mask = Iterables.getOnlyElement(uniqueMasks); if (!child.getMarkerSymbol().equals(mask)) { return Optional.empty(); @@ -98,7 +117,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato child.getHashSymbol(), Optional.empty()), // remove DISTINCT flag from function calls - parent.getAssignments() + parent.getAggregations() .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsByPartitionsOrder.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java similarity index 55% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsByPartitionsOrder.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java index 3f277076b7018..460f60612200c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsByPartitionsOrder.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SwapAdjacentWindowsBySpecifications.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -27,17 +27,22 @@ import java.util.Optional; import static com.facebook.presto.sql.planner.iterative.rule.Util.transpose; +import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.dependsOn; -public class SwapAdjacentWindowsByPartitionsOrder +public class SwapAdjacentWindowsBySpecifications implements Rule { + private static final Pattern PATTERN = Pattern.node(WindowNode.class); + @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof WindowNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { WindowNode parent = (WindowNode) node; PlanNode child = lookup.resolve(parent.getSource()); @@ -53,18 +58,23 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato } } - private static boolean dependsOn(WindowNode parent, WindowNode child) + private static int compare(WindowNode o1, WindowNode o2) { - return parent.getPartitionBy().stream().anyMatch(child.getCreatedSymbols()::contains) - || parent.getOrderBy().stream().anyMatch(child.getCreatedSymbols()::contains) - || parent.getWindowFunctions().values().stream() - .map(WindowNode.Function::getFunctionCall) - .map(DependencyExtractor::extractUnique) - .flatMap(symbols -> symbols.stream()) - .anyMatch(child.getCreatedSymbols()::contains); + int comparison = comparePartitionBy(o1, o2); + if (comparison != 0) { + return comparison; + } + + comparison = compareOrderBy(o1, o2); + if (comparison != 0) { + return comparison; + } + + // If PartitionBy and OrderBy clauses are identical, let's establish an arbitrary order to prevent non-deterministic results of swapping WindowNodes in such a case + return o1.getId().toString().compareTo(o2.getId().toString()); } - private static int compare(WindowNode o1, WindowNode o2) + private static int comparePartitionBy(WindowNode o1, WindowNode o2) { Iterator iterator1 = o1.getPartitionBy().iterator(); Iterator iterator2 = o2.getPartitionBy().iterator(); @@ -73,21 +83,48 @@ private static int compare(WindowNode o1, WindowNode o2) Symbol symbol1 = iterator1.next(); Symbol symbol2 = iterator2.next(); - int comparison = symbol1.compareTo(symbol2); - if (comparison != 0) { - return comparison; + int partitionByComparison = symbol1.compareTo(symbol2); + if (partitionByComparison != 0) { + return partitionByComparison; } } if (iterator1.hasNext()) { return 1; } - if (iterator2.hasNext()) { return -1; } + return 0; + } + + private static int compareOrderBy(WindowNode o1, WindowNode o2) + { + Iterator iterator1 = o1.getOrderBy().iterator(); + Iterator iterator2 = o2.getOrderBy().iterator(); + + while (iterator1.hasNext() && iterator2.hasNext()) { + Symbol symbol1 = iterator1.next(); + Symbol symbol2 = iterator2.next(); + + int orderByComparison = symbol1.compareTo(symbol2); + if (orderByComparison != 0) { + return orderByComparison; + } + else { + int sortOrderComparison = o1.getOrderings().get(symbol1).compareTo(o2.getOrderings().get(symbol2)); + if (sortOrderComparison != 0) { + return sortOrderComparison; + } + } + } - // If both are equal, let's establish an arbitrary order to prevent non-deterministic results of swapping WindowNodes with identical PartitionBy clauses - return o1.getId().toString().compareTo(o2.getId().toString()); + if (iterator1.hasNext()) { + return 1; + } + if (iterator2.hasNext()) { + return -1; + } + return 0; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java new file mode 100644 index 0000000000000..1586de8483a69 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -0,0 +1,427 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin; +import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedInPredicateSubqueryToSemiJoin; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanVisitor; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.BooleanLiteral; +import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.ComparisonExpressionType; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.IsNotNullPredicate; +import com.facebook.presto.sql.tree.IsNullPredicate; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NotExpression; +import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SearchedCaseExpression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.sql.tree.WhenClause; +import com.facebook.presto.sql.tree.Window; +import com.facebook.presto.sql.util.AstUtils; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.ExpressionUtils.or; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +/** + * Replaces correlated ApplyNode with InPredicate expression with SemiJoin + *

+ * Transforms: + *

+ * - Apply (output: a in B.b)
+ *    - input: some plan A producing symbol a
+ *    - subquery: some plan B producing symbol b, using symbols from A
+ * 
+ * Into: + *
+ * - Project (output: CASE WHEN (countmatches > 0) THEN true WHEN (countnullmatches > 0) THEN null ELSE false END)
+ *   - Aggregate (countmatches=count(*) where a, b not null; countnullmatches where a,b null but buildSideKnownNonNull is not null)
+ *     grouping by (A'.*)
+ *     - LeftJoin on (A and B correlation condition)
+ *       - AssignUniqueId (A')
+ *         - A
+ * 
+ *

+ * + * @see TransformCorrelatedScalarAggregationToJoin + * @see TransformUncorrelatedInPredicateSubqueryToSemiJoin + */ +public class TransformCorrelatedInPredicateToJoin + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(ApplyNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof ApplyNode)) { + return Optional.empty(); + } + + ApplyNode apply = (ApplyNode) node; + + if (apply.getCorrelation().isEmpty()) { + return Optional.empty(); + } + + Assignments subqueryAssignments = apply.getSubqueryAssignments(); + if (subqueryAssignments.size() != 1) { + return Optional.empty(); + } + Expression assignmentExpression = getOnlyElement(subqueryAssignments.getExpressions()); + if (!(assignmentExpression instanceof InPredicate)) { + return Optional.empty(); + } + + InPredicate inPredicate = (InPredicate) assignmentExpression; + Symbol inPredicateOutputSymbol = getOnlyElement(subqueryAssignments.getSymbols()); + + return apply(apply, inPredicate, inPredicateOutputSymbol, lookup, idAllocator, symbolAllocator); + } + + private Optional apply( + ApplyNode apply, + InPredicate inPredicate, + Symbol inPredicateOutputSymbol, + Lookup lookup, + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator) + { + Optional decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation()) + .decorrelate(apply.getSubquery()); + + if (!decorrelated.isPresent()) { + return Optional.empty(); + } + + PlanNode projection = buildInPredicateEquivalent( + apply, + inPredicate, + inPredicateOutputSymbol, + decorrelated.get(), + idAllocator, + symbolAllocator + ); + + return Optional.of(projection); + } + + private PlanNode buildInPredicateEquivalent( + ApplyNode apply, + InPredicate inPredicate, + Symbol inPredicateOutputSymbol, + Decorrelated decorrelated, + PlanNodeIdAllocator idAllocator, + SymbolAllocator symbolAllocator) + { + Expression correlationCondition = and(decorrelated.getCorrelatedPredicates()); + PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode(); + + AssignUniqueId probeSide = new AssignUniqueId( + idAllocator.getNextId(), + apply.getInput(), + symbolAllocator.newSymbol("unique", BIGINT)); + + Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT); + ProjectNode buildSide = new ProjectNode( + idAllocator.getNextId(), + decorrelatedBuildSource, + Assignments.builder() + .putAll(Assignments.identity(decorrelatedBuildSource.getOutputSymbols())) + .put(buildSideKnownNonNull, bigint(0)) + .build() + ); + + Symbol probeSideSymbol = Symbol.from(inPredicate.getValue()); + Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList()); + + Expression joinExpression = and( + or( + new IsNullPredicate(probeSideSymbol.toSymbolReference()), + new ComparisonExpression(ComparisonExpressionType.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()), + new IsNullPredicate(buildSideSymbol.toSymbolReference()) + ), + correlationCondition + ); + + JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression); + + Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BIGINT); + Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BIGINT); + + Expression matchCondition = and( + isNotNull(probeSideSymbol), + isNotNull(buildSideSymbol) + ); + + Expression nullMatchCondition = and( + isNotNull(buildSideKnownNonNull), + not(matchCondition) + ); + + AggregationNode aggregation = new AggregationNode( + idAllocator.getNextId(), + leftOuterJoin, + ImmutableMap.builder() + .put(countMatchesSymbol, countWithFilter(matchCondition)) + .put(countNullMatchesSymbol, countWithFilter(nullMatchCondition)) + .build(), + ImmutableList.of(probeSide.getOutputSymbols()), + AggregationNode.Step.SINGLE, + Optional.empty(), + Optional.empty() + ); + + // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results + SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression( + ImmutableList.of( + new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), + new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null)) + ), + Optional.of(booleanConstant(false)) + ); + return new ProjectNode( + idAllocator.getNextId(), + aggregation, + Assignments.builder() + .putAll(Assignments.identity(apply.getInput().getOutputSymbols())) + .put(inPredicateOutputSymbol, inPredicateEquivalent) + .build() + ); + } + + private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinExpression) + { + return new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.LEFT, + probeSide, + buildSide, + ImmutableList.of(), + ImmutableList.builder() + .addAll(probeSide.getOutputSymbols()) + .addAll(buildSide.getOutputSymbols()) + .build(), + Optional.of(joinExpression), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + private static AggregationNode.Aggregation countWithFilter(Expression condition) + { + FunctionCall countCall = new FunctionCall( + QualifiedName.of("count"), + Optional.empty(), + Optional.of(condition), + false, + ImmutableList.of() /* arguments */ + ); + + return new AggregationNode.Aggregation( + countCall, + new Signature("count", FunctionKind.AGGREGATE, BIGINT.getTypeSignature()), + Optional.empty() /* mask */ + ); + } + + private static Expression isGreaterThan(Symbol symbol, long value) + { + return new ComparisonExpression( + ComparisonExpressionType.GREATER_THAN, + symbol.toSymbolReference(), + bigint(value) + ); + } + + private static Expression not(Expression booleanExpression) + { + return new NotExpression(booleanExpression); + } + + private static Expression isNotNull(Symbol symbol) + { + return new IsNotNullPredicate(symbol.toSymbolReference()); + } + + private static Expression bigint(long value) + { + return new Cast(new LongLiteral(String.valueOf(value)), BIGINT.toString()); + } + + private static Expression booleanConstant(@Nullable Boolean value) + { + if (value == null) { + return new Cast(new NullLiteral(), BOOLEAN.toString()); + } + return new BooleanLiteral(value.toString()); + } + + /** + * TODO consult comon parts with {@link com.facebook.presto.sql.planner.optimizations.TransformCorrelatedScalarAggregationToJoin.Rewriter#decorrelateFilters} + */ + private static class DecorrelatingVisitor + extends PlanVisitor, PlanNode> + { + private final Lookup lookup; + private final Set correlation; + + public DecorrelatingVisitor(Lookup lookup, Iterable correlation) + { + this.lookup = requireNonNull(lookup, "lookup is null"); + this.correlation = ImmutableSet.copyOf(requireNonNull(correlation, "correlation is null")); + } + + public Optional decorrelate(PlanNode reference) + { + return lookup.resolve(reference).accept(this, reference); + } + + @Override + public Optional visitProject(ProjectNode node, PlanNode reference) + { + if (isCorrelatedShallowly(node)) { + // TODO: handle correlated projection + return Optional.empty(); + } + + Optional result = decorrelate(node.getSource()); + return result.map(decorrelated -> { + Assignments.Builder assignments = Assignments.builder() + .putAll(node.getAssignments()); + + // Pull up all symbols used by a filter (except correlation) + decorrelated.getCorrelatedPredicates().stream() + .flatMap(AstUtils::preOrder) + .filter(SymbolReference.class::isInstance) + .map(SymbolReference.class::cast) + .filter(symbolReference -> !correlation.contains(Symbol.from(symbolReference))) + .forEach(symbolReference -> assignments.putIdentity(Symbol.from(symbolReference))); + + return new Decorrelated( + decorrelated.getCorrelatedPredicates(), + new ProjectNode( + node.getId(), // FIXME should I reuse or not? + decorrelated.getDecorrelatedNode(), + assignments.build() + ) + ); + }); + } + + @Override + public Optional visitFilter(FilterNode node, PlanNode reference) + { + Optional result = decorrelate(node.getSource()); + return result.map(decorrelated -> + new Decorrelated( + ImmutableList.builder() + .addAll(decorrelated.getCorrelatedPredicates()) + // No need to retain uncorrelated conditions, predicate push down will push them back + .add(node.getPredicate()) + .build(), + decorrelated.getDecorrelatedNode() + )); + } + + @Override + protected Optional visitPlan(PlanNode node, PlanNode reference) + { + if (isCorrelatedRecursively(node)) { + return Optional.empty(); + } + else { + return Optional.of(new Decorrelated(ImmutableList.of(), reference)); + } + } + + private boolean isCorrelatedRecursively(PlanNode node) + { + if (isCorrelatedShallowly(node)) { + return true; + } + return node.getSources().stream() + .map(lookup::resolve) + .anyMatch(this::isCorrelatedRecursively); + } + + private boolean isCorrelatedShallowly(PlanNode node) + { + return SymbolsExtractor.extractUniqueNonRecursive(node).stream().anyMatch(correlation::contains); + } + } + + private static class Decorrelated + { + private final List correlatedPredicates; + private final PlanNode decorrelatedNode; + + public Decorrelated(List correlatedPredicates, PlanNode decorrelatedNode) + { + this.correlatedPredicates = ImmutableList.copyOf(requireNonNull(correlatedPredicates, "correlatedPredicates is null")); + this.decorrelatedNode = requireNonNull(decorrelatedNode, "decorrelatedNode is null"); + } + + public List getCorrelatedPredicates() + { + return correlatedPredicates; + } + + public PlanNode getDecorrelatedNode() + { + return decorrelatedNode; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java new file mode 100644 index 0000000000000..9ddf7ce43f6ee --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.ScalarAggregationToJoinRewriter; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; +import static java.util.Objects.requireNonNull; + +public class TransformCorrelatedScalarAggregationToJoin + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + private final FunctionRegistry functionRegistry; + + public TransformCorrelatedScalarAggregationToJoin(FunctionRegistry functionRegistry) + { + this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null"); + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + if (!(node instanceof LateralJoinNode)) { + return Optional.empty(); + } + + LateralJoinNode lateralJoinNode = (LateralJoinNode) node; + PlanNode subquery = lookup.resolve(lateralJoinNode.getSubquery()); + + if (lateralJoinNode.getCorrelation().isEmpty() || !(isScalar(subquery, lookup))) { + return Optional.empty(); + } + + Optional aggregation = findAggregation(subquery, lookup); + if (!(aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty())) { + return Optional.empty(); + } + + ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(functionRegistry, symbolAllocator, idAllocator, lookup); + + PlanNode rewrittenNode = rewriter.rewriteScalarAggregation(lateralJoinNode, aggregation.get()); + + if (rewrittenNode instanceof LateralJoinNode) { + return Optional.empty(); + } + + return Optional.of(rewrittenNode); + } + + private static Optional findAggregation(PlanNode rootNode, Lookup lookup) + { + return searchFrom(rootNode, lookup) + .where(AggregationNode.class::isInstance) + .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) + .findFirst(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToScalarApply.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java similarity index 88% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToScalarApply.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 243b7be15f761..3b0e82e8a4230 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToScalarApply.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -20,11 +20,13 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Cast; @@ -40,6 +42,7 @@ import java.util.Optional; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.INNER; import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; @@ -53,26 +56,29 @@ * -- subquery *

*/ -public class TransformExistsApplyToScalarApply +public class TransformExistsApplyToLateralNode implements Rule { + private static final Pattern PATTERN = Pattern.node(ApplyNode.class); private static final QualifiedName COUNT = QualifiedName.of("count"); private static final FunctionCall COUNT_CALL = new FunctionCall(COUNT, ImmutableList.of()); private final Signature countSignature; - public TransformExistsApplyToScalarApply(FunctionRegistry functionRegistry) + public TransformExistsApplyToLateralNode(FunctionRegistry functionRegistry) { requireNonNull(functionRegistry, "functionRegistry is null"); countSignature = functionRegistry.resolveFunction(COUNT, ImmutableList.of()); } @Override - public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Pattern getPattern() { - if (!(node instanceof ApplyNode)) { - return Optional.empty(); - } + return PATTERN; + } + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { ApplyNode parent = (ApplyNode) node; if (parent.getSubqueryAssignments().size() != 1) { @@ -88,7 +94,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato Symbol exists = getOnlyElement(parent.getSubqueryAssignments().getSymbols()); return Optional.of( - new ApplyNode( + new LateralJoinNode( node.getId(), parent.getInput(), new ProjectNode( @@ -102,7 +108,7 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato Optional.empty(), Optional.empty()), Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))), - Assignments.of(exists, exists.toSymbolReference()), - parent.getCorrelation())); + parent.getCorrelation(), + INNER)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java new file mode 100644 index 0000000000000..1bb5fb31b23cb --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.InPredicate; + +import java.util.Optional; + +import static com.google.common.collect.Iterables.getOnlyElement; + +/** + * This optimizers looks for InPredicate expressions in ApplyNodes and replaces the nodes with SemiJoin nodes. + *

+ * Plan before optimizer: + *

+ * Filter(a IN b):
+ *   Apply
+ *     - correlation: []  // empty
+ *     - input: some plan A producing symbol a
+ *     - subquery: some plan B producing symbol b
+ * 
+ *

+ * Plan after optimizer: + *

+ * Filter(semijoinresult):
+ *   SemiJoin
+ *     - source: plan A
+ *     - filteringSource: symbol a
+ *     - sourceJoinSymbol: plan B
+ *     - filteringSourceJoinSymbol: symbol b
+ *     - semiJoinOutput: semijoinresult
+ * 
+ */ +public class TransformUncorrelatedInPredicateSubqueryToSemiJoin + implements Rule +{ + @Override + public Pattern getPattern() + { + return Pattern.node(ApplyNode.class); + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + ApplyNode applyNode = (ApplyNode) node; + + if (!applyNode.getCorrelation().isEmpty()) { + return Optional.empty(); + } + + if (applyNode.getSubqueryAssignments().size() != 1) { + return Optional.empty(); + } + + Expression expression = getOnlyElement(applyNode.getSubqueryAssignments().getExpressions()); + if (!(expression instanceof InPredicate)) { + return Optional.empty(); + } + + InPredicate inPredicate = (InPredicate) expression; + Symbol semiJoinSymbol = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols()); + + SemiJoinNode replacement = new SemiJoinNode(idAllocator.getNextId(), + applyNode.getInput(), + applyNode.getSubquery(), + Symbol.from(inPredicate.getValue()), + Symbol.from(inPredicate.getValueList()), + semiJoinSymbol, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + + return Optional.of(replacement); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java new file mode 100644 index 0000000000000..c3533093a5f5d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; + +import java.util.Optional; + +public class TransformUncorrelatedLateralToJoin + implements Rule +{ + private static final Pattern PATTERN = Pattern.node(LateralJoinNode.class); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + LateralJoinNode lateralJoinNode = (LateralJoinNode) node; + + if (!lateralJoinNode.getCorrelation().isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.INNER, + lateralJoinNode.getInput(), + lateralJoinNode.getSubquery(), + ImmutableList.of(), + ImmutableList.builder() + .addAll(lateralJoinNode.getInput().getOutputSymbols()) + .addAll(lateralJoinNode.getSubquery().getOutputSymbols()) + .build(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java index 5c61227ca478e..f6480bd0cc826 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java @@ -13,19 +13,25 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.DependencyExtractor; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import java.util.Collection; -import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; + class Util { private Util() @@ -33,24 +39,20 @@ private Util() } /** - * Prune the list of available inputs to those required by the given expressions. + * Prune the set of available inputs to those required by the given expressions. * * If all inputs are used, return Optional.empty() to indicate that no pruning is necessary. */ - public static Optional> pruneInputs(Collection availableInputs, Collection expressions) + public static Optional> pruneInputs(Collection availableInputs, Collection expressions) { - Set available = new HashSet<>(availableInputs); - Set required = DependencyExtractor.extractUnique(expressions); - - // we need to compute the intersection in case some dependencies are symbols from - // the outer scope (i.e., correlated queries) - Set used = Sets.intersection(required, available); - if (used.size() == available.size()) { - // no need to prune... every available input is being used + Set availableInputsSet = ImmutableSet.copyOf(availableInputs); + Set prunedInputs = Sets.filter(availableInputsSet, SymbolsExtractor.extractUnique(expressions)::contains); + + if (prunedInputs.size() == availableInputsSet.size()) { return Optional.empty(); } - return Optional.of(ImmutableList.copyOf(used)); + return Optional.of(prunedInputs); } /** @@ -62,4 +64,55 @@ public static PlanNode transpose(PlanNode parent, PlanNode child) parent.replaceChildren( child.getSources()))); } + + /** + * @return If the node has outputs not in permittedOutputs, returns an identity projection containing only those node outputs also in permittedOutputs. + */ + public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set permittedOutputs) + { + List restrictedOutputs = node.getOutputSymbols().stream() + .filter(permittedOutputs::contains) + .collect(toImmutableList()); + + if (restrictedOutputs.size() == node.getOutputSymbols().size()) { + return Optional.empty(); + } + + return Optional.of( + new ProjectNode( + idAllocator.getNextId(), + node, + Assignments.identity(restrictedOutputs))); + } + + /** + * @return The original node, with identity projections possibly inserted between node and each child, limiting the columns to those permitted. + * Returns a present Optional iff at least one child was rewritten. + */ + @SafeVarargs + public static Optional restrictChildOutputs(PlanNodeIdAllocator idAllocator, PlanNode node, Set... permittedChildOutputsArgs) + { + List> permittedChildOutputs = ImmutableList.copyOf(permittedChildOutputsArgs); + + checkArgument( + (node.getSources().size() == permittedChildOutputs.size()), + "Mismatched child (%d) and permitted outputs (%d) sizes", + node.getSources().size(), + permittedChildOutputs.size()); + + ImmutableList.Builder newChildrenBuilder = ImmutableList.builder(); + boolean rewroteChildren = false; + + for (int i = 0; i < node.getSources().size(); ++i) { + PlanNode oldChild = node.getSources().get(i); + Optional newChild = restrictOutputs(idAllocator, oldChild, permittedChildOutputs.get(i)); + rewroteChildren |= newChild.isPresent(); + newChildrenBuilder.add(newChild.orElse(oldChild)); + } + + if (!rewroteChildren) { + return Optional.empty(); + } + return Optional.of(node.replaceChildren(newChildrenBuilder.build())); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java index 26bacf0f7dc8a..338171761aae9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java @@ -90,9 +90,9 @@ public boolean isSingleNode() return global.isSingleNode(); } - public boolean isNullsReplicated() + public boolean isNullsAndAnyReplicated() { - return global.isNullsReplicated(); + return global.isNullsAndAnyReplicated(); } public boolean isStreamPartitionedOn(Collection columns) @@ -100,9 +100,9 @@ public boolean isStreamPartitionedOn(Collection columns) return isStreamPartitionedOn(columns, false); } - public boolean isStreamPartitionedOn(Collection columns, boolean nullsReplicated) + public boolean isStreamPartitionedOn(Collection columns, boolean nullsAndAnyReplicated) { - return global.isStreamPartitionedOn(columns, constants.keySet(), nullsReplicated); + return global.isStreamPartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); } public boolean isNodePartitionedOn(Collection columns) @@ -110,14 +110,14 @@ public boolean isNodePartitionedOn(Collection columns) return isNodePartitionedOn(columns, false); } - public boolean isNodePartitionedOn(Collection columns, boolean nullsReplicated) + public boolean isNodePartitionedOn(Collection columns, boolean nullsAndAnyReplicated) { - return global.isNodePartitionedOn(columns, constants.keySet(), nullsReplicated); + return global.isNodePartitionedOn(columns, constants.keySet(), nullsAndAnyReplicated); } - public boolean isNodePartitionedOn(Partitioning partitioning, boolean nullsReplicated) + public boolean isNodePartitionedOn(Partitioning partitioning, boolean nullsAndAnyReplicated) { - return global.isNodePartitionedOn(partitioning, nullsReplicated); + return global.isNodePartitionedOn(partitioning, nullsAndAnyReplicated); } public boolean isNodePartitionedWith(ActualProperties other, Function> symbolMappings) @@ -286,14 +286,14 @@ public static final class Global // the rows will be partitioned into a single node or stream. However, this can still be a partitioned plan in that the plan // will be executed on multiple servers, but only one server will get all the data. - // Description of whether rows with nulls in partitioning columns have been replicated to all *nodes* - private final boolean nullsReplicated; + // Description of whether rows with nulls in partitioning columns or some arbitrary rows have been replicated to all *nodes* + private final boolean nullsAndAnyReplicated; - private Global(Optional nodePartitioning, Optional streamPartitioning, boolean nullsReplicated) + private Global(Optional nodePartitioning, Optional streamPartitioning, boolean nullsAndAnyReplicated) { this.nodePartitioning = requireNonNull(nodePartitioning, "nodePartitioning is null"); this.streamPartitioning = requireNonNull(streamPartitioning, "streamPartitioning is null"); - this.nullsReplicated = nullsReplicated; + this.nullsAndAnyReplicated = nullsAndAnyReplicated; } public static Global coordinatorSingleStreamPartition() @@ -346,9 +346,9 @@ public Global withReplicatedNulls(boolean replicatedNulls) return new Global(nodePartitioning, streamPartitioning, replicatedNulls); } - private boolean isNullsReplicated() + private boolean isNullsAndAnyReplicated() { - return nullsReplicated; + return nullsAndAnyReplicated; } /** @@ -372,14 +372,14 @@ private boolean isCoordinatorOnly() return nodePartitioning.get().getHandle().isCoordinatorOnly(); } - private boolean isNodePartitionedOn(Collection columns, Set constants, boolean nullsReplicated) + private boolean isNodePartitionedOn(Collection columns, Set constants, boolean nullsAndAnyReplicated) { - return nodePartitioning.isPresent() && nodePartitioning.get().isPartitionedOn(columns, constants) && this.nullsReplicated == nullsReplicated; + return nodePartitioning.isPresent() && nodePartitioning.get().isPartitionedOn(columns, constants) && this.nullsAndAnyReplicated == nullsAndAnyReplicated; } - private boolean isNodePartitionedOn(Partitioning partitioning, boolean nullsReplicated) + private boolean isNodePartitionedOn(Partitioning partitioning, boolean nullsAndAnyReplicated) { - return nodePartitioning.isPresent() && nodePartitioning.get().equals(partitioning) && this.nullsReplicated == nullsReplicated; + return nodePartitioning.isPresent() && nodePartitioning.get().equals(partitioning) && this.nullsAndAnyReplicated == nullsAndAnyReplicated; } private boolean isNodePartitionedWith( @@ -395,7 +395,7 @@ private boolean isNodePartitionedWith( symbolMappings, leftConstantMapping, rightConstantMapping) && - nullsReplicated == other.nullsReplicated; + nullsAndAnyReplicated == other.nullsAndAnyReplicated; } private Optional getNodePartitioning() @@ -403,9 +403,9 @@ private Optional getNodePartitioning() return nodePartitioning; } - private boolean isStreamPartitionedOn(Collection columns, Set constants, boolean nullsReplicated) + private boolean isStreamPartitionedOn(Collection columns, Set constants, boolean nullsAndAnyReplicated) { - return streamPartitioning.isPresent() && streamPartitioning.get().isPartitionedOn(columns, constants) && this.nullsReplicated == nullsReplicated; + return streamPartitioning.isPresent() && streamPartitioning.get().isPartitionedOn(columns, constants) && this.nullsAndAnyReplicated == nullsAndAnyReplicated; } /** @@ -413,7 +413,7 @@ private boolean isStreamPartitionedOn(Collection columns, Set co */ private boolean isEffectivelySingleStream(Set constants) { - return streamPartitioning.isPresent() && streamPartitioning.get().isEffectivelySinglePartition(constants) && !nullsReplicated; + return streamPartitioning.isPresent() && streamPartitioning.get().isEffectivelySinglePartition(constants) && !nullsAndAnyReplicated; } /** @@ -421,7 +421,7 @@ private boolean isEffectivelySingleStream(Set constants) */ private boolean isStreamRepartitionEffective(Collection keys, Set constants) { - return (!streamPartitioning.isPresent() || streamPartitioning.get().isRepartitionEffective(keys, constants)) && !nullsReplicated; + return (!streamPartitioning.isPresent() || streamPartitioning.get().isRepartitionEffective(keys, constants)) && !nullsAndAnyReplicated; } private Global translate(Function> translator, Function> constants) @@ -429,13 +429,13 @@ private Global translate(Function> translator, Function return new Global( nodePartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)), streamPartitioning.flatMap(partitioning -> partitioning.translate(translator, constants)), - nullsReplicated); + nullsAndAnyReplicated); } @Override public int hashCode() { - return Objects.hash(nodePartitioning, streamPartitioning, nullsReplicated); + return Objects.hash(nodePartitioning, streamPartitioning, nullsAndAnyReplicated); } @Override @@ -450,7 +450,7 @@ public boolean equals(Object obj) final Global other = (Global) obj; return Objects.equals(this.nodePartitioning, other.nodePartitioning) && Objects.equals(this.streamPartitioning, other.streamPartitioning) && - this.nullsReplicated == other.nullsReplicated; + this.nullsAndAnyReplicated == other.nullsAndAnyReplicated; } @Override @@ -459,7 +459,7 @@ public String toString() return toStringHelper(this) .add("nodePartitioning", nodePartitioning) .add("streamPartitioning", streamPartitioning) - .add("nullsReplicated", nullsReplicated) + .add("nullsAndAnyReplicated", nullsAndAnyReplicated) .toString(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 95dd8ce61a129..beb61008ba265 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -26,7 +26,6 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DomainTranslator; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.LookupSymbolResolver; @@ -35,6 +34,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; @@ -48,6 +48,7 @@ import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -68,9 +69,9 @@ import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.SymbolReference; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -123,6 +124,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -180,7 +182,7 @@ List getCorrelations() } private class Rewriter - extends PlanVisitor + extends PlanVisitor { private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; @@ -249,7 +251,11 @@ public PlanWithProperties visitAggregation(AggregationNode node, Context context { Set partitioningRequirement = ImmutableSet.copyOf(node.getGroupingKeys()); - PreferredProperties preferredProperties = PreferredProperties.any(); + boolean preferSingleNode = (node.hasEmptyGroupingSet() && !node.hasNonEmptyGroupingSet()) || + (node.hasDefaultOutput() && !node.isDecomposable(metadata.getFunctionRegistry())); + + PreferredProperties preferredProperties = preferSingleNode ? PreferredProperties.undistributed() : PreferredProperties.any(); + if (!node.getGroupingKeys().isEmpty()) { preferredProperties = PreferredProperties.partitionedWithLocal(partitioningRequirement, grouped(node.getGroupingKeys())) .mergeWithParent(context.getPreferredProperties()); @@ -262,8 +268,7 @@ public PlanWithProperties visitAggregation(AggregationNode node, Context context return rebaseAndDeriveProperties(node, child); } - if ((node.hasEmptyGroupingSet() && !node.hasNonEmptyGroupingSet()) || - (node.hasDefaultOutput() && !node.isDecomposable(metadata.getFunctionRegistry()))) { + if (preferSingleNode) { // For queries with only empty grouping sets like // // SELECT count(*) FROM lineitem; @@ -437,18 +442,23 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, Context con @Override public PlanWithProperties visitTopN(TopNNode node, Context context) { - PlanWithProperties child = planChild(node, context.withPreferredProperties(PreferredProperties.any())); - - if (!child.getProperties().isSingleNode()) { - child = withDerivedProperties( - new TopNNode(idAllocator.getNextId(), child.getNode(), node.getCount(), node.getOrderBy(), node.getOrderings(), true), - child.getProperties()); - - child = withDerivedProperties( - gatheringExchange(idAllocator.getNextId(), REMOTE, child.getNode()), - child.getProperties()); + PlanWithProperties child; + switch (node.getStep()) { + case SINGLE: + case FINAL: + child = planChild(node, context.withPreferredProperties(PreferredProperties.undistributed())); + if (!child.getProperties().isSingleNode()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE, child.getNode()), + child.getProperties()); + } + break; + case PARTIAL: + child = planChild(node, context.withPreferredProperties(PreferredProperties.any())); + break; + default: + throw new UnsupportedOperationException(format("Unsupported step for TopN [%s]", node.getStep())); } - return rebaseAndDeriveProperties(node, child); } @@ -508,7 +518,7 @@ public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, Context con gatheringExchange( idAllocator.getNextId(), REMOTE, - new DistinctLimitNode(idAllocator.getNextId(), child.getNode(), node.getLimit(), true, node.getHashSymbol())), + new DistinctLimitNode(idAllocator.getNextId(), child.getNode(), node.getLimit(), true, node.getDistinctSymbols(), node.getHashSymbol())), child.getProperties()); } @@ -652,7 +662,7 @@ private PlanWithProperties pickPlan(List possiblePlans, Cont private boolean shouldPrune(Expression predicate, Map assignments, Map bindings, List correlations) { List conjuncts = extractConjuncts(predicate); - IdentityLinkedHashMap expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, parser, @@ -664,7 +674,7 @@ private boolean shouldPrune(Expression predicate, Map assi // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned for (Expression expression : conjuncts) { - if (DependencyExtractor.extractUnique(expression).stream().anyMatch(correlations::contains)) { + if (SymbolsExtractor.extractUnique(expression).stream().anyMatch(correlations::contains)) { // expression contains correlated symbol with outer query continue; } @@ -859,7 +869,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, Context context) if (source.getProperties().isNodePartitionedOn(sourceSymbols) && !source.getProperties().isSingleNode()) { Partitioning filteringPartitioning = source.getProperties().translate(createTranslator(sourceToFiltering)).getNodePartitioning().get(); - filteringSource = node.getFilteringSource().accept(this, context.withPreferredProperties(PreferredProperties.partitionedWithNullsReplicated(filteringPartitioning))); + filteringSource = node.getFilteringSource().accept(this, context.withPreferredProperties(PreferredProperties.partitionedWithNullsAndAnyReplicated(filteringPartitioning))); if (!source.getProperties().withReplicatedNulls(true).isNodePartitionedWith(filteringSource.getProperties(), sourceToFiltering::get)) { filteringSource = withDerivedProperties( partitionedExchange(idAllocator.getNextId(), REMOTE, filteringSource.getNode(), new PartitioningScheme( @@ -872,7 +882,7 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, Context context) } } else { - filteringSource = node.getFilteringSource().accept(this, context.withPreferredProperties(PreferredProperties.partitionedWithNullsReplicated(ImmutableSet.copyOf(filteringSourceSymbols)))); + filteringSource = node.getFilteringSource().accept(this, context.withPreferredProperties(PreferredProperties.partitionedWithNullsAndAnyReplicated(ImmutableSet.copyOf(filteringSourceSymbols)))); if (filteringSource.getProperties().isNodePartitionedOn(filteringSourceSymbols, true) && !filteringSource.getProperties().isSingleNode()) { Partitioning sourcePartitioning = filteringSource.getProperties().translate(createTranslator(filteringToSource)).getNodePartitioning().get(); @@ -1015,14 +1025,14 @@ private Partitioning selectUnionPartitioning(UnionNode node, Context context, Pr } // Try planning the children to see if any of them naturally produce a partitioning (for now, just select the first) - boolean nullsReplicated = parentPreference.isNullsReplicated(); + boolean nullsAndAnyReplicated = parentPreference.isNullsAndAnyReplicated(); for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) { PreferredProperties.PartitioningProperties childPartitioning = parentPreference.translate(outputToInputTranslator(node, sourceIndex)).get(); PreferredProperties childPreferred = PreferredProperties.builder() - .global(PreferredProperties.Global.distributed(childPartitioning.withNullsReplicated(nullsReplicated))) + .global(PreferredProperties.Global.distributed(childPartitioning.withNullsAndAnyReplicated(nullsAndAnyReplicated))) .build(); PlanWithProperties child = node.getSources().get(sourceIndex).accept(this, context.withPreferredProperties(childPreferred)); - if (child.getProperties().isNodePartitionedOn(childPartitioning.getPartitioningColumns(), nullsReplicated)) { + if (child.getProperties().isNodePartitionedOn(childPartitioning.getPartitioningColumns(), nullsAndAnyReplicated)) { Function> childToParent = createTranslator(createMapping(node.sourceOutputLayout(sourceIndex), node.getOutputSymbols())); return child.getProperties().translate(childToParent).getNodePartitioning().get(); } @@ -1039,7 +1049,7 @@ public PlanWithProperties visitUnion(UnionNode node, Context context) Optional parentGlobal = parentPreference.getGlobalProperties(); if (parentGlobal.isPresent() && parentGlobal.get().isDistributed() && parentGlobal.get().getPartitioningProperties().isPresent()) { PreferredProperties.PartitioningProperties parentPartitioningPreference = parentGlobal.get().getPartitioningProperties().get(); - boolean nullsReplicated = parentPartitioningPreference.isNullsReplicated(); + boolean nullsAndAnyReplicated = parentPartitioningPreference.isNullsAndAnyReplicated(); Partitioning desiredParentPartitioning = selectUnionPartitioning(node, context, parentPartitioningPreference); ImmutableList.Builder partitionedSources = ImmutableList.builder(); @@ -1050,11 +1060,11 @@ public PlanWithProperties visitUnion(UnionNode node, Context context) PreferredProperties childPreferred = PreferredProperties.builder() .global(PreferredProperties.Global.distributed(PreferredProperties.PartitioningProperties.partitioned(childPartitioning) - .withNullsReplicated(nullsReplicated))) + .withNullsAndAnyReplicated(nullsAndAnyReplicated))) .build(); PlanWithProperties source = node.getSources().get(sourceIndex).accept(this, context.withPreferredProperties(childPreferred)); - if (!source.getProperties().isNodePartitionedOn(childPartitioning, nullsReplicated)) { + if (!source.getProperties().isNodePartitionedOn(childPartitioning, nullsAndAnyReplicated)) { source = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), @@ -1064,7 +1074,7 @@ public PlanWithProperties visitUnion(UnionNode node, Context context) childPartitioning, source.getNode().getOutputSymbols(), Optional.empty(), - nullsReplicated, + nullsAndAnyReplicated, Optional.empty())), source.getProperties()); } @@ -1085,7 +1095,7 @@ public PlanWithProperties visitUnion(UnionNode node, Context context) ActualProperties.builder() .global(partitionedOn(desiredParentPartitioning, Optional.of(desiredParentPartitioning))) .build() - .withReplicatedNulls(parentPartitioningPreference.isNullsReplicated())); + .withReplicatedNulls(parentPartitioningPreference.isNullsAndAnyReplicated())); } // first, classify children into partitioned and unpartitioned @@ -1216,6 +1226,21 @@ public PlanWithProperties visitApply(ApplyNode node, Context context) return new PlanWithProperties(rewritten, deriveProperties(rewritten, ImmutableList.of(input.getProperties(), subquery.getProperties()))); } + @Override + public PlanWithProperties visitLateralJoin(LateralJoinNode node, Context context) + { + PlanWithProperties input = node.getInput().accept(this, context); + PlanWithProperties subquery = node.getSubquery().accept(this, context.withCorrelations(node.getCorrelation())); + + LateralJoinNode rewritten = new LateralJoinNode( + node.getId(), + input.getNode(), + subquery.getNode(), + node.getCorrelation(), + node.getType()); + return new PlanWithProperties(rewritten, deriveProperties(rewritten, ImmutableList.of(input.getProperties(), subquery.getProperties()))); + } + private PlanWithProperties planChild(PlanNode node, Context context) { return getOnlyElement(node.getSources()).accept(this, context); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 2ca4b73398018..7dc8e3ec06313 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -104,7 +104,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types } private class Rewriter - extends PlanVisitor + extends PlanVisitor { private final PlanNodeIdAllocator idAllocator; private final Session session; @@ -169,7 +169,7 @@ public PlanWithProperties visitTableFinish(TableFinishNode node, StreamPreferred @Override public PlanWithProperties visitTopN(TopNNode node, StreamPreferredProperties parentPreferences) { - if (node.isPartial()) { + if (node.getStep().equals(TopNNode.Step.PARTIAL)) { return planAndEnforceChildren( node, parentPreferences.withoutPreference().withDefaultParallelism(session), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CountConstantOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CountConstantOptimizer.java index f3d59bb6279a6..7fde4049bc86b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CountConstantOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CountConstantOptimizer.java @@ -21,6 +21,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; @@ -61,19 +62,20 @@ private static class Rewriter @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { - Map aggregations = new LinkedHashMap<>(node.getAggregations()); - Map functions = new LinkedHashMap<>(node.getFunctions()); + Map aggregations = new LinkedHashMap<>(node.getAggregations()); PlanNode source = context.rewrite(node.getSource()); if (source instanceof ProjectNode) { ProjectNode projectNode = (ProjectNode) source; - for (Entry entry : node.getAggregations().entrySet()) { + for (Entry entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); - FunctionCall functionCall = entry.getValue(); - Signature signature = node.getFunctions().get(symbol); - if (isCountConstant(projectNode, functionCall, signature)) { - aggregations.put(symbol, new FunctionCall(functionCall.getName(), functionCall.getWindow(), functionCall.getFilter(), functionCall.isDistinct(), ImmutableList.of())); - functions.put(symbol, new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT))); + Aggregation aggregation = entry.getValue(); + FunctionCall functionCall = aggregation.getCall(); + if (isCountConstant(projectNode, functionCall, aggregation.getSignature())) { + aggregations.put(symbol, new Aggregation( + new FunctionCall(functionCall.getName(), functionCall.getWindow(), functionCall.getFilter(), functionCall.isDistinct(), ImmutableList.of()), + new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT)), + aggregation.getMask())); } } } @@ -82,8 +84,6 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont node.getId(), source, aggregations, - functions, - node.getMasks(), node.getGroupingSets(), node.getStep(), node.getHashSymbol(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DesugaringOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DesugaringOptimizer.java index 17ccd453c0e49..1b08faf1311c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DesugaringOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DesugaringOptimizer.java @@ -36,8 +36,9 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.GroupingOperation; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.SymbolReference; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import java.util.Map; import java.util.Optional; @@ -102,7 +103,7 @@ public PlanNode visitPlan(PlanNode node, RewriteContext context) public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - Map assignments = node.getAssignments().entrySet().stream() + Map aggregations = node.getAggregations().entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> { Aggregation aggregation = entry.getValue(); return new Aggregation((FunctionCall) desugar(aggregation.getCall()), aggregation.getSignature(), aggregation.getMask()); @@ -110,7 +111,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont return new AggregationNode( node.getId(), source, - assignments, + aggregations, node.getGroupingSets(), node.getStep(), node.getHashSymbol(), @@ -194,10 +195,12 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) private Expression desugar(Expression expression) { + checkState(!(expression instanceof GroupingOperation), "GroupingOperation should have been re-written to a FunctionCall before execution"); + if (expression instanceof SymbolReference) { return expression; } - IdentityLinkedHashMap expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList() /* parameters already replaced */); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList() /* parameters already replaced */); expression = new LambdaCaptureDesugaringRewriter(types, symbolAllocator).rewrite(expression); expression = ExpressionTreeRewriter.rewriteWith(new DesugaringRewriter(expressionTypes), expression); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DistinctOutputQueryUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DistinctOutputQueryUtil.java new file mode 100644 index 0000000000000..41c1bef03c787 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DistinctOutputQueryUtil.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.DistinctLimitNode; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; +import com.facebook.presto.sql.planner.plan.ExceptNode; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.IntersectNode; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanVisitor; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.TopNNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; + +import java.util.function.Function; + +import static java.util.function.Function.identity; + +public final class DistinctOutputQueryUtil +{ + private DistinctOutputQueryUtil() {} + + public static boolean isDistinct(PlanNode node) + { + return node.accept(new IsDistinctPlanVisitor(identity()), null); + } + + public static boolean isDistinct(PlanNode node, Function lookupFunction) + { + return node.accept(new IsDistinctPlanVisitor(lookupFunction), null); + } + + private static final class IsDistinctPlanVisitor + extends PlanVisitor + { + /* + With the iterative optimizer, plan nodes are replaced with + GroupReference nodes. This requires the rules that look deeper + in the tree than the rewritten node and its immediate sources + to use Lookup for resolving the nodes corresponding to the + GroupReferences. + */ + private final Function lookupFunction; + + private IsDistinctPlanVisitor(Function lookupFunction) + { + this.lookupFunction = lookupFunction; + } + + @Override + protected Boolean visitPlan(PlanNode node, Void context) + { + return false; + } + + @Override + public Boolean visitAggregation(AggregationNode node, Void context) + { + return true; + } + + @Override + public Boolean visitAssignUniqueId(AssignUniqueId node, Void context) + { + return true; + } + + @Override + public Boolean visitDistinctLimit(DistinctLimitNode node, Void context) + { + return true; + } + + @Override + public Boolean visitEnforceSingleRow(EnforceSingleRowNode node, Void context) + { + return true; + } + + @Override + public Boolean visitExcept(ExceptNode node, Void context) + { + return true; + } + + @Override + public Boolean visitFilter(FilterNode node, Void context) + { + return lookupFunction.apply(node.getSource()).accept(this, null); + } + + @Override + public Boolean visitIntersect(IntersectNode node, Void context) + { + return true; + } + + @Override + public Boolean visitProject(ProjectNode node, Void context) + { + return node.isIdentity() && lookupFunction.apply(node.getSource()).accept(this, null); + } + + @Override + public Boolean visitValues(ValuesNode node, Void context) + { + return node.getRows().size() == 1; + } + + @Override + public Boolean visitLimit(LimitNode node, Void context) + { + return node.getCount() <= 1; + } + + @Override + public Boolean visitTopN(TopNNode node, Void context) + { + return node.getCount() <= 1; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java index 1309fe919bb4b..89a4eccfc674c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java @@ -20,30 +20,22 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; -import com.facebook.presto.sql.planner.plan.Assignments; -import com.facebook.presto.sql.planner.plan.FilterNode; -import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; -import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.PriorityQueue; -import java.util.Set; +import java.util.Objects; +import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.buildJoinTree; +import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; +import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; +@Deprecated public class EliminateCrossJoins implements PlanOptimizer { @@ -74,66 +66,6 @@ public PlanNode optimize( return plan; } - public static boolean isOriginalOrder(List joinOrder) - { - for (int i = 0; i < joinOrder.size(); i++) { - if (joinOrder.get(i) != i) { - return false; - } - } - return true; - } - - /** - * Given JoinGraph determine the order of joins between graph nodes - * by traversing JoinGraph. Any graph traversal algorithm could be used - * here (like BFS or DFS), but we use PriorityQueue to preserve - * original JoinOrder as mush as it is possible. PriorityQueue returns - * next nodes to join in order of their occurrence in original Plan. - */ - public static List getJoinOrder(JoinGraph graph) - { - ImmutableList.Builder joinOrder = ImmutableList.builder(); - - Map priorities = new HashMap<>(); - for (int i = 0; i < graph.size(); i++) { - priorities.put(graph.getNode(i).getId(), i); - } - - PriorityQueue nodesToVisit = new PriorityQueue<>( - graph.size(), - (Comparator) (node1, node2) -> priorities.get(node1.getId()).compareTo(priorities.get(node2.getId()))); - Set visited = new HashSet<>(); - - nodesToVisit.add(graph.getNode(0)); - - while (!nodesToVisit.isEmpty()) { - PlanNode node = nodesToVisit.poll(); - if (!visited.contains(node)) { - visited.add(node); - joinOrder.add(node); - for (JoinGraph.Edge edge : graph.getEdges(node)) { - nodesToVisit.add(edge.getTargetNode()); - } - } - - if (nodesToVisit.isEmpty() && visited.size() < graph.size()) { - // disconnected graph, find new starting point - Optional firstNotVisitedNode = graph.getNodes().stream() - .filter(graphNode -> !visited.contains(graphNode)) - .findFirst(); - if (firstNotVisitedNode.isPresent()) { - nodesToVisit.add(firstNotVisitedNode.get()); - } - } - } - - checkState(visited.size() == graph.size()); - return joinOrder.build().stream() - .map(node -> priorities.get(node.getId())) - .collect(toImmutableList()); - } - private class Rewriter extends SimplePlanRewriter { @@ -145,78 +77,18 @@ public Rewriter(PlanNodeIdAllocator idAllocator, JoinGraph graph, List { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.graph = requireNonNull(graph, "graph is null"); - this.joinOrder = requireNonNull(joinOrder, "joinOrder is null"); + this.joinOrder = ImmutableList.copyOf(requireNonNull(joinOrder, "joinOrder is null")); checkState(joinOrder.size() >= 2); } @Override public PlanNode visitPlan(PlanNode node, RewriteContext context) { - if (node.getId() != graph.getRootId()) { + if (!Objects.equals(node.getId(), graph.getRootId())) { return context.defaultRewrite(node, context.get()); } - PlanNode result = graph.getNode(joinOrder.get(0)); - Set alreadyJoinedNodes = new HashSet<>(); - alreadyJoinedNodes.add(result.getId()); - - for (int i = 1; i < joinOrder.size(); i++) { - PlanNode rightNode = graph.getNode(joinOrder.get(i)); - alreadyJoinedNodes.add(rightNode.getId()); - - ImmutableList.Builder criteria = ImmutableList.builder(); - - for (JoinGraph.Edge edge : graph.getEdges(rightNode)) { - PlanNode targetNode = edge.getTargetNode(); - if (alreadyJoinedNodes.contains(targetNode.getId())) { - criteria.add(new JoinNode.EquiJoinClause( - edge.getTargetSymbol(), - edge.getSourceSymbol())); - } - } - - result = new JoinNode( - idAllocator.getNextId(), - JoinNode.Type.INNER, - result, - rightNode, - criteria.build(), - ImmutableList.builder() - .addAll(result.getOutputSymbols()) - .addAll(rightNode.getOutputSymbols()) - .build(), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()); - } - - List filters = graph.getFilters(); - - for (Expression filter : filters) { - result = new FilterNode( - idAllocator.getNextId(), - result, - filter); - } - - if (graph.getAssignments().isPresent()) { - result = new ProjectNode( - idAllocator.getNextId(), - result, - Assignments.copyOf(graph.getAssignments().get())); - } - - if (!result.getOutputSymbols().equals(node.getOutputSymbols())) { - // Introduce a projection to constrain the outputs to what was originally expected - // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) - result = new ProjectNode( - idAllocator.getNextId(), - result, - Assignments.identity(node.getOutputSymbols())); - } - - return result; + return buildJoinTree(node.getOutputSymbols(), graph, joinOrder, idAllocator); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java index c699db0c1f009..061713624a390 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java @@ -28,7 +28,7 @@ import com.facebook.presto.sql.relational.RowExpressionVisitor; import com.facebook.presto.sql.relational.VariableReferenceExpression; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ComparisonChain; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -100,7 +100,7 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma Expression expressionWithInputReferences = new SymbolToInputRewriter(symbolInput).rewrite(expression); // determine the type of every expression - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypesFromInput( session, metadata, sqlParser, @@ -113,7 +113,7 @@ private RowExpression toRowExpression(Session session, Expression expression, Ma } private static class CanonicalizationVisitor - implements RowExpressionVisitor + implements RowExpressionVisitor { @Override public RowExpression visitCall(CallExpression call, Void context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 60d926e41d0f2..9743293359793 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -33,6 +33,7 @@ import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanVisitor; @@ -112,7 +113,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types } private static class Rewriter - extends PlanVisitor + extends PlanVisitor { private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; @@ -146,6 +147,14 @@ public PlanWithProperties visitApply(ApplyNode node, HashComputationSet context) return new PlanWithProperties(node, ImmutableMap.of()); } + @Override + public PlanWithProperties visitLateralJoin(LateralJoinNode node, HashComputationSet context) + { + // Lateral join node is not supported by execution, so do not rewrite it + // that way query will fail in sanity checkers + return new PlanWithProperties(node, ImmutableMap.of()); + } + @Override public PlanWithProperties visitAggregation(AggregationNode node, HashComputationSet parentPreference) { @@ -165,8 +174,6 @@ public PlanWithProperties visitAggregation(AggregationNode node, HashComputation idAllocator.getNextId(), child.getNode(), node.getAggregations(), - node.getFunctions(), - node.getMasks(), node.getGroupingSets(), node.getStep(), hashSymbol, @@ -204,7 +211,7 @@ public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, HashComputa Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get()); return new PlanWithProperties( - new DistinctLimitNode(idAllocator.getNextId(), child.getNode(), node.getLimit(), node.isPartial(), Optional.of(hashSymbol)), + new DistinctLimitNode(idAllocator.getNextId(), child.getNode(), node.getLimit(), node.isPartial(), node.getDistinctSymbols(), Optional.of(hashSymbol)), child.getHashSymbols()); } @@ -488,7 +495,7 @@ public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet pa .collect(toImmutableList())) .build(), partitionSymbols.map(newHashSymbols::get), - partitioningScheme.isReplicateNulls(), + partitioningScheme.isReplicateNullsAndAny(), partitioningScheme.getBucketToPartition()); // add hash symbols to sources diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index 153d37dc11903..175d58919cf7d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -22,6 +22,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExceptNode; import com.facebook.presto.sql.planner.plan.FilterNode; @@ -226,20 +227,19 @@ private UnionNode union(List nodes, List outputs) } private AggregationNode computeCounts(UnionNode sourceNode, List originalColumns, List markers, List aggregationOutputs) { - ImmutableMap.Builder signatures = ImmutableMap.builder(); - ImmutableMap.Builder aggregations = ImmutableMap.builder(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); for (int i = 0; i < markers.size(); i++) { Symbol output = aggregationOutputs.get(i); - aggregations.put(output, new FunctionCall(QualifiedName.of("count"), ImmutableList.of(markers.get(i).toSymbolReference()))); - signatures.put(output, COUNT_AGGREGATION); + aggregations.put(output, new Aggregation( + new FunctionCall(QualifiedName.of("count"), ImmutableList.of(markers.get(i).toSymbolReference())), + COUNT_AGGREGATION, + Optional.empty())); } return new AggregationNode(idAllocator.getNextId(), sourceNode, aggregations.build(), - signatures.build(), - ImmutableMap.of(), ImmutableList.of(originalColumns), Step.SINGLE, Optional.empty(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index ddaf8782b3a6b..ad0951bc3a8c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -469,7 +469,7 @@ public static Map trace(PlanNode node, Set lookupSymbols } private static class Visitor - extends PlanVisitor, Map> + extends PlanVisitor, Set> { @Override protected Map visitPlan(PlanNode node, Set lookupSymbols) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java index 715910bf39a59..2df062a9030b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/LimitPushDown.java @@ -139,7 +139,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) if (limit != null) { count = Math.min(count, limit.getCount()); } - return new TopNNode(node.getId(), rewrittenSource, count, node.getOrderBy(), node.getOrderings(), node.isPartial()); + return new TopNNode(node.getId(), rewrittenSource, count, node.getOrderBy(), node.getOrderings(), node.getStep()); } @Override @@ -190,7 +190,7 @@ public PlanNode visitSort(SortNode node, RewriteContext context) PlanNode rewrittenSource = context.rewrite(node.getSource()); if (limit != null) { - return new TopNNode(node.getId(), rewrittenSource, limit.getCount(), node.getOrderBy(), node.getOrderings(), false); + return new TopNNode(node.getId(), rewrittenSource, limit.getCount(), node.getOrderBy(), node.getOrderings(), TopNNode.Step.SINGLE); } else if (rewrittenSource != node.getSource()) { return new SortNode(node.getId(), rewrittenSource, node.getOrderBy(), node.getOrderings()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeProjections.java deleted file mode 100644 index 509524e34f159..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeProjections.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.planner.optimizations; - -import com.facebook.presto.Session; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.DeterminismEvaluator; -import com.facebook.presto.sql.planner.ExpressionSymbolInliner; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.plan.Assignments; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.TryExpression; -import com.facebook.presto.sql.util.AstUtils; -import com.google.common.collect.ImmutableList; - -import java.util.Map; - -import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; -import static java.util.Objects.requireNonNull; - -/** - * Merges chains of consecutive projections - */ -@Deprecated -public class MergeProjections - implements PlanOptimizer -{ - @Override - public PlanNode optimize(PlanNode plan, Session session, Map types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) - { - requireNonNull(plan, "plan is null"); - requireNonNull(session, "session is null"); - requireNonNull(types, "types is null"); - requireNonNull(symbolAllocator, "symbolAllocator is null"); - requireNonNull(idAllocator, "idAllocator is null"); - - return SimplePlanRewriter.rewriteWith(new Rewriter(), plan); - } - - private static class Rewriter - extends SimplePlanRewriter - { - @Override - public PlanNode visitProject(ProjectNode node, RewriteContext context) - { - PlanNode source = context.rewrite(node.getSource()); - - if (source instanceof ProjectNode) { - ProjectNode sourceProject = (ProjectNode) source; - if (isDeterministic(sourceProject) && !containsTry(node)) { - Assignments.Builder projections = Assignments.builder(); - for (Map.Entry projection : node.getAssignments().entrySet()) { - Expression inlined = new ExpressionSymbolInliner(sourceProject.getAssignments().getMap()).rewrite(projection.getValue()); - projections.put(projection.getKey(), inlined); - } - - return new ProjectNode(node.getId(), sourceProject.getSource(), projections.build()); - } - } - return replaceChildren(node, ImmutableList.of(source)); - } - - private static boolean isDeterministic(ProjectNode node) - { - return node.getAssignments().getExpressions().stream().allMatch(DeterminismEvaluator::isDeterministic); - } - - private static boolean containsTry(ProjectNode node) - { - return node.getAssignments() - .getExpressions().stream() - .flatMap(AstUtils::preOrder) - .anyMatch(TryExpression.class::isInstance); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeWindows.java deleted file mode 100644 index 9845100d708ef..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeWindows.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.planner.optimizations; - -import com.facebook.presto.Session; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.DependencyExtractor; -import com.facebook.presto.sql.planner.PlanNodeIdAllocator; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.planner.plan.WindowNode; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.Multimap; - -import java.util.Collection; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static com.google.common.base.Preconditions.checkState; - -/** - * Merge together the functions in WindowNodes that have identical WindowNode.Specifications. - * For example: - *

- * OutputNode - * `--... - * `--WindowNode(Specification: A, Functions: [sum(something)]) - * `--WindowNode(Specification: B, Functions: [sum(something)]) - * `--WindowNode(Specification: A, Functions: [avg(something)]) - * `--... - * - * Will be transformed into - *

- * OutputNode - * `--... - * `--WindowNode(Specification: B, Functions: [sum(something)]) - * `--WindowNode(Specification: A, Functions: [avg(something), sum(something)]) - * `--... - * - * This will NOT merge the functions in WindowNodes that have identical WindowNode.Specifications, - * but have a node between them that is not a WindowNode. - * In the following example, the functions in the WindowNodes with specification `A' will not be - * merged into a single WindowNode. - *

- * OutputNode - * `--... - * `--WindowNode(Specification: A, Functions: [sum(something)]) - * `--WindowNode(Specification: B, Functions: [sum(something)]) - * `-- ProjectNode(...) - * `--WindowNode(Specification: A, Functions: [avg(something)]) - * `--... - */ -public class MergeWindows - implements PlanOptimizer -{ - @Override - public PlanNode optimize(PlanNode plan, - Session session, - Map types, - SymbolAllocator symbolAllocator, - PlanNodeIdAllocator idAllocator) - { - // ImmutableListMultimap preserves order of window nodes - return SimplePlanRewriter.rewriteWith(new Rewriter(), plan, ImmutableListMultimap.of()); - } - - private static class Rewriter - extends SimplePlanRewriter> - { - @Override - protected PlanNode visitPlan( - PlanNode node, - RewriteContext> context) - { - PlanNode newNode = context.defaultRewrite(node, ImmutableListMultimap.of()); - return collapseWindowsWithinSpecification(context.get(), newNode); - } - - @Override - public PlanNode visitWindow( - WindowNode windowNode, - RewriteContext> context) - { - checkState(!windowNode.getHashSymbol().isPresent(), "MergeWindows should be run before HashGenerationOptimizer"); - checkState(windowNode.getPrePartitionedInputs().isEmpty() && windowNode.getPreSortedOrderPrefix() == 0, "MergeWindows should be run before AddExchanges"); - checkState(windowNode.getWindowFunctions().values().stream().distinct().count() == 1, "Frames expected to be identical"); - - for (WindowNode.Specification specification : context.get().keySet()) { - Collection nodes = context.get().get(specification); - if (nodes.stream().anyMatch(node -> dependsOn(node, windowNode))) { - return collapseWindowsWithinSpecification(context.get(), - context.rewrite( - windowNode.getSource(), - ImmutableListMultimap.of(windowNode.getSpecification(), windowNode))); - } - } - - return context.rewrite( - windowNode.getSource(), - ImmutableListMultimap.builder() - .put(windowNode.getSpecification(), windowNode) // Add the current window first so that it gets precedence in iteration order - .putAll(context.get()) - .build()); - } - - private static PlanNode collapseWindowsWithinSpecification(Multimap windowsMap, PlanNode sourceNode) - { - for (WindowNode.Specification specification : windowsMap.keySet()) { - Collection windows = windowsMap.get(specification); - sourceNode = collapseWindows(sourceNode, specification, windows); - } - return sourceNode; - } - - private static WindowNode collapseWindows(PlanNode source, WindowNode.Specification specification, Collection windows) - { - WindowNode canonical = windows.iterator().next(); - return new WindowNode( - canonical.getId(), - source, - specification, - windows.stream() - .map(WindowNode::getWindowFunctions) - .flatMap(map -> map.entrySet().stream()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)), - canonical.getHashSymbol(), - canonical.getPrePartitionedInputs(), - canonical.getPreSortedOrderPrefix()); - } - - private static boolean dependsOn(WindowNode parent, WindowNode child) - { - Set childOutputs = child.getCreatedSymbols(); - - Stream arguments = parent.getWindowFunctions().values().stream() - .map(WindowNode.Function::getFunctionCall) - .flatMap(functionCall -> functionCall.getArguments().stream()) - .map(DependencyExtractor::extractUnique) - .flatMap(Collection::stream); - - return parent.getPartitionBy().stream().anyMatch(childOutputs::contains) - || parent.getOrderBy().stream().anyMatch(childOutputs::contains) - || arguments.anyMatch(childOutputs::contains); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java index ffa00f866cc13..4bd6534b9006a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -31,6 +31,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; @@ -42,7 +43,6 @@ import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -100,8 +100,8 @@ private Optimizer(Session session, Metadata metadata, PlanNodeIdAllocator idAllo public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { // supported functions are only MIN/MAX/APPROX_DISTINCT or distinct aggregates - for (FunctionCall call : node.getAggregations().values()) { - if (!ALLOWED_FUNCTIONS.contains(call.getName().toString()) && !call.isDistinct()) { + for (Aggregation aggregation : node.getAggregations().values()) { + if (!ALLOWED_FUNCTIONS.contains(aggregation.getCall().getName().toString()) && !aggregation.getCall().isDistinct()) { return context.defaultRewrite(node); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index affc18164c447..7b791e8965fe9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; @@ -44,7 +45,6 @@ import com.google.common.collect.Iterables; import java.util.ArrayList; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -54,6 +54,7 @@ import static com.facebook.presto.SystemSessionProperties.isOptimizeDistinctAggregationEnabled; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /* @@ -109,21 +110,22 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext masks = ImmutableSet.copyOf(node.getMasks().values()); - if (masks.size() != 1 || node.getMasks().size() == node.getAggregations().size()) { + List masks = node.getAggregations().values().stream() + .map(Aggregation::getMask).filter(Optional::isPresent).map(Optional::get).collect(toImmutableList()); + Set uniqueMasks = ImmutableSet.copyOf(masks); + if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); } - if (node.getAggregations().values().stream().map(FunctionCall::getFilter).anyMatch(Optional::isPresent)) { + if (node.getAggregations().values().stream().map(Aggregation::getCall).map(FunctionCall::getFilter).anyMatch(Optional::isPresent)) { // Skip if any aggregation contains a filter return context.defaultRewrite(node, Optional.empty()); } AggregateInfo aggregateInfo = new AggregateInfo( node.getGroupingKeys(), - Iterables.getOnlyElement(masks), - node.getAggregations(), - node.getFunctions()); + Iterables.getOnlyElement(uniqueMasks), + node.getAggregations()); if (!checkAllEquatableTypes(aggregateInfo)) { // This optimization relies on being able to GROUP BY arguments @@ -141,29 +143,27 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext aggregations = ImmutableMap.builder(); - ImmutableMap.Builder functions = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { - FunctionCall functionCall = entry.getValue(); - if (entry.getValue().isDistinct()) { - aggregations.put( - entry.getKey(), new FunctionCall( + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { + FunctionCall functionCall = entry.getValue().getCall(); + if (functionCall.isDistinct()) { + aggregations.put(entry.getKey(), new Aggregation( + new FunctionCall( functionCall.getName(), functionCall.getWindow(), false, - ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference()))); - functions.put(entry.getKey(), node.getFunctions().get(entry.getKey())); + ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference())), + entry.getValue().getSignature(), + Optional.empty())); } else { // Aggregations on non-distinct are already done by new node, just extract the non-null value Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(entry.getKey()); QualifiedName functionName = QualifiedName.of("arbitrary"); - aggregations.put(entry.getKey(), new FunctionCall( - functionName, - functionCall.getWindow(), - false, - ImmutableList.of(argument.toSymbolReference()))); - functions.put(entry.getKey(), getFunctionSignature(functionName, argument)); + aggregations.put(entry.getKey(), new Aggregation( + new FunctionCall(functionName, functionCall.getWindow(), false, ImmutableList.of(argument.toSymbolReference())), + getFunctionSignature(functionName, argument), + Optional.empty())); } } @@ -171,8 +171,6 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext aggregationOutputSymbolsMapBuilder = ImmutableMap.builder(); AggregationNode aggregationNode = createNonDistinctAggregation( aggregateInfo.get(), distinctSymbol, @@ -384,25 +382,20 @@ private AggregationNode createNonDistinctAggregation( Set groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode originalNode, - ImmutableMap.Builder aggregationOutputSymbolsMapBuilder + ImmutableMap.Builder aggregationOutputSymbolsMapBuilder ) { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - ImmutableMap.Builder functions = ImmutableMap.builder(); - for (Map.Entry entry : aggregateInfo.getAggregations().entrySet()) { - FunctionCall functionCall = entry.getValue(); + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : aggregateInfo.getAggregations().entrySet()) { + FunctionCall functionCall = entry.getValue().getCall(); if (!functionCall.isDistinct()) { Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), symbolAllocator.getTypes().get(entry.getKey())); aggregationOutputSymbolsMapBuilder.put(newSymbol, entry.getKey()); - if (duplicatedDistinctSymbol.equals(distinctSymbol)) { - // Mask symbol was not present in aggregations without mask - aggregations.put(newSymbol, functionCall); - } - else { + if (!duplicatedDistinctSymbol.equals(distinctSymbol)) { // Handling for cases when mask symbol appears in non distinct aggregations too // Now the aggregation should happen over the duplicate symbol added before if (functionCall.getArguments().contains(distinctSymbol.toSymbolReference())) { - ImmutableList.Builder arguments = ImmutableList.builder(); + ImmutableList.Builder arguments = ImmutableList.builder(); for (Expression argument : functionCall.getArguments()) { if (distinctSymbol.toSymbolReference().equals(argument)) { arguments.add(duplicatedDistinctSymbol.toSymbolReference()); @@ -411,21 +404,16 @@ private AggregationNode createNonDistinctAggregation( arguments.add(argument); } } - aggregations.put(newSymbol, new FunctionCall(functionCall.getName(), functionCall.getWindow(), false, arguments.build())); - } - else { - aggregations.put(newSymbol, functionCall); + functionCall = new FunctionCall(functionCall.getName(), functionCall.getWindow(), false, arguments.build()); } } - functions.put(newSymbol, aggregateInfo.getFunctions().get(entry.getKey())); + aggregations.put(newSymbol, new Aggregation(functionCall, entry.getValue().getSignature(), Optional.empty())); } } return new AggregationNode( idAllocator.getNextId(), groupIdNode, aggregations.build(), - functions.build(), - Collections.emptyMap(), ImmutableList.of(ImmutableList.copyOf(groupByKeys)), SINGLE, originalNode.getHashSymbol(), @@ -454,27 +442,26 @@ private static class AggregateInfo { private final List groupBySymbols; private final Symbol mask; - private final Map aggregations; - private final Map functions; + private final Map aggregations; // Filled on the way back, these are the symbols corresponding to their distinct or non-distinct original symbols private Map newNonDistinctAggregateSymbols; private Symbol newDistinctAggregateSymbol; private boolean foundMarkDistinct; - public AggregateInfo(List groupBySymbols, Symbol mask, Map aggregations, Map functions) + public AggregateInfo(List groupBySymbols, Symbol mask, Map aggregations) { this.groupBySymbols = ImmutableList.copyOf(groupBySymbols); this.mask = mask; this.aggregations = ImmutableMap.copyOf(aggregations); - this.functions = ImmutableMap.copyOf(functions); } public List getOriginalNonDistinctAggregateArgs() { return aggregations.values().stream() + .map(Aggregation::getCall) .filter(function -> !function.isDistinct()) .flatMap(function -> function.getArguments().stream()) .distinct() @@ -485,6 +472,7 @@ public List getOriginalNonDistinctAggregateArgs() public List getOriginalDistinctAggregateArgs() { return aggregations.values().stream() + .map(Aggregation::getCall) .filter(FunctionCall::isDistinct) .flatMap(function -> function.getArguments().stream()) .distinct() @@ -522,16 +510,11 @@ public List getGroupBySymbols() return groupBySymbols; } - public Map getAggregations() + public Map getAggregations() { return aggregations; } - public Map getFunctions() - { - return functions; - } - public void foundMarkDistinct() { foundMarkDistinct = true; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java index 55cc67b44e97b..06460719c061c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java @@ -23,9 +23,13 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; @@ -33,14 +37,21 @@ import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.isPushAggregationThroughJoin; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -48,6 +59,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class PartialAggregationPushDown @@ -65,7 +77,7 @@ public PartialAggregationPushDown(FunctionRegistry registry) @Override public PlanNode optimize(PlanNode plan, Session session, Map types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { - return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator), plan, null); + return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator, isPushAggregationThroughJoin(session)), plan, null); } private class Rewriter @@ -73,11 +85,13 @@ private class Rewriter { private final SymbolAllocator allocator; private final PlanNodeIdAllocator idAllocator; + private final boolean pushAggregationThroughJoin; - public Rewriter(SymbolAllocator allocator, PlanNodeIdAllocator idAllocator) + public Rewriter(SymbolAllocator allocator, PlanNodeIdAllocator idAllocator, boolean pushAggregationThroughJoin) { this.allocator = requireNonNull(allocator, "allocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.pushAggregationThroughJoin = pushAggregationThroughJoin; } @Override @@ -85,6 +99,10 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont { PlanNode child = node.getSource(); + if (child instanceof JoinNode && pushAggregationThroughJoin) { + return pushPartialThroughJoin(node, (JoinNode) child, context); + } + if (!(child instanceof ExchangeNode)) { return context.defaultRewrite(node); } @@ -108,7 +126,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont // the cardinality of the stream (i.e., gather or repartition) ExchangeNode exchange = (ExchangeNode) child; if ((exchange.getType() != GATHER && exchange.getType() != REPARTITION) || - exchange.getPartitioningScheme().isReplicateNulls()) { + exchange.getPartitioningScheme().isReplicateNullsAndAny()) { return context.defaultRewrite(node); } @@ -149,6 +167,129 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext cont } } + private PlanNode pushPartialThroughJoin(AggregationNode node, JoinNode child, RewriteContext context) + { + if (node.getStep() != PARTIAL || node.getGroupingSets().size() != 1) { + return context.defaultRewrite(node); + } + + if (child.getType() != JoinNode.Type.INNER || child.getFilter().isPresent()) { + // TODO: add support for filter function. + // All availableSymbols used in filter function could be added to pushedDownGroupingSet + return context.defaultRewrite(node); + } + + // TODO: leave partial aggregation above Join? + if (allAggregationsOn(node.getAggregations(), child.getLeft().getOutputSymbols())) { + return pushPartialToLeftChild(node, child, context); + } + else if (allAggregationsOn(node.getAggregations(), child.getRight().getOutputSymbols())) { + return pushPartialToRightChild(node, child, context); + } + + return context.defaultRewrite(node); + } + + private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, RewriteContext context) + { + List groupingSet = getPushedDownGroupingSet(node, child, ImmutableSet.copyOf(child.getLeft().getOutputSymbols())); + AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), child.getCriteria(), groupingSet, context); + return pushPartialToJoin(pushedAggregation, child, pushedAggregation, context.rewrite(child.getRight()), child.getRight().getOutputSymbols()); + } + + private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, RewriteContext context) + { + List groupingSet = getPushedDownGroupingSet(node, child, ImmutableSet.copyOf(child.getRight().getOutputSymbols())); + AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), child.getCriteria(), groupingSet, context); + return pushPartialToJoin(pushedAggregation, child, context.rewrite(child.getLeft()), pushedAggregation, child.getLeft().getOutputSymbols()); + } + + private PlanNode pushPartialToJoin( + AggregationNode pushedAggregation, + JoinNode child, + PlanNode leftChild, + PlanNode rightChild, + Collection otherSymbols) + { + ImmutableList.Builder outputSymbols = ImmutableList.builder(); + outputSymbols.addAll(pushedAggregation.getOutputSymbols()); + outputSymbols.addAll(otherSymbols); + + return new JoinNode( + child.getId(), + child.getType(), + leftChild, + rightChild, + child.getCriteria(), + outputSymbols.build(), + child.getFilter(), + child.getLeftHashSymbol(), + child.getRightHashSymbol(), + child.getDistributionType()); + } + + private AggregationNode replaceAggregationSource( + AggregationNode aggregation, + PlanNode source, + List criteria, + List groupingSet, + RewriteContext context) + { + PlanNode rewrittenSource = context.rewrite(source); + ImmutableSet rewrittenSourceSymbols = ImmutableSet.copyOf(rewrittenSource.getOutputSymbols()); + ImmutableMap.Builder mapping = ImmutableMap.builder(); + + for (EquiJoinClause joinClause : criteria) { + if (rewrittenSourceSymbols.contains(joinClause.getLeft())) { + mapping.put(joinClause.getRight(), joinClause.getLeft()); + } + else { + mapping.put(joinClause.getLeft(), joinClause.getRight()); + } + } + + AggregationNode pushedAggregation = new AggregationNode( + aggregation.getId(), + aggregation.getSource(), + aggregation.getAggregations(), + ImmutableList.of(groupingSet), + aggregation.getStep(), + aggregation.getHashSymbol(), + aggregation.getGroupIdSymbol()); + return new SymbolMapper(mapping.build()).map(pushedAggregation, source); + } + + private boolean allAggregationsOn(Map aggregations, List outputSymbols) + { + Set inputs = SymbolsExtractor.extractUnique(aggregations.values().stream().map(Aggregation::getCall).collect(toImmutableList())); + return outputSymbols.containsAll(inputs); + } + + private List getPushedDownGroupingSet(AggregationNode aggregation, JoinNode join, Set availableSymbols) + { + List groupingSet = Iterables.getOnlyElement(aggregation.getGroupingSets()); + Set joinKeys = Stream.concat( + join.getCriteria().stream().map(EquiJoinClause::getLeft), + join.getCriteria().stream().map(EquiJoinClause::getRight) + ).collect(Collectors.toSet()); + + // keep symbols that are either directly from the join's child (availableSymbols) or there is + // an equality in join condition to a symbol for the join child + List pushedDownGroupingSet = groupingSet.stream() + .filter(symbol -> joinKeys.contains(symbol) || availableSymbols.contains(symbol)) + .collect(Collectors.toList()); + + if (pushedDownGroupingSet.size() != groupingSet.size() || pushedDownGroupingSet.isEmpty()) { + // If we dropped some symbol, we have to add all join key columns to the grouping set + Set existingSymbols = ImmutableSet.copyOf(pushedDownGroupingSet); + + join.getCriteria().stream() + .filter(equiJoinClause -> !existingSymbols.contains(equiJoinClause.getLeft()) && !existingSymbols.contains(equiJoinClause.getRight())) + .forEach(joinClause -> pushedDownGroupingSet.add(joinClause.getLeft())); + } + return pushedDownGroupingSet; + } + private PlanNode pushPartial(AggregationNode partial, ExchangeNode exchange) { List partials = new ArrayList<>(); @@ -186,7 +327,7 @@ private PlanNode pushPartial(AggregationNode partial, ExchangeNode exchange) exchange.getPartitioningScheme().getPartitioning(), partial.getOutputSymbols(), exchange.getPartitioningScheme().getHashColumn(), - exchange.getPartitioningScheme().isReplicateNulls(), + exchange.getPartitioningScheme().isReplicateNullsAndAny(), exchange.getPartitioningScheme().getBucketToPartition()); return new ExchangeNode( @@ -201,33 +342,28 @@ private PlanNode pushPartial(AggregationNode partial, ExchangeNode exchange) private PlanNode split(AggregationNode node) { // otherwise, add a partial and final with an exchange in between - Map masks = node.getMasks(); - - Map finalCalls = new HashMap<>(); - Map intermediateCalls = new HashMap<>(); - Map intermediateFunctions = new HashMap<>(); - Map intermediateMask = new HashMap<>(); - for (Map.Entry entry : node.getAggregations().entrySet()) { - Signature signature = node.getFunctions().get(entry.getKey()); + Map intermediateAggregation = new HashMap<>(); + Map finalAggregation = new HashMap<>(); + for (Map.Entry entry : node.getAggregations().entrySet()) { + Aggregation originalAggregation = entry.getValue(); + Signature signature = originalAggregation.getSignature(); InternalAggregationFunction function = functionRegistry.getAggregateFunctionImplementation(signature); - Symbol intermediateSymbol = allocator.newSymbol(signature.getName(), function.getIntermediateType()); - intermediateCalls.put(intermediateSymbol, entry.getValue()); - intermediateFunctions.put(intermediateSymbol, signature); - if (masks.containsKey(entry.getKey())) { - intermediateMask.put(intermediateSymbol, masks.get(entry.getKey())); - } + + intermediateAggregation.put(intermediateSymbol, new Aggregation(originalAggregation.getCall(), signature, originalAggregation.getMask())); // rewrite final aggregation in terms of intermediate function - finalCalls.put(entry.getKey(), new FunctionCall(QualifiedName.of(signature.getName()), ImmutableList.of(intermediateSymbol.toSymbolReference()))); + finalAggregation.put(entry.getKey(), + new Aggregation( + new FunctionCall(QualifiedName.of(signature.getName()), ImmutableList.of(intermediateSymbol.toSymbolReference())), + signature, + Optional.empty())); } PlanNode partial = new AggregationNode( idAllocator.getNextId(), node.getSource(), - intermediateCalls, - intermediateFunctions, - intermediateMask, + intermediateAggregation, node.getGroupingSets(), PARTIAL, node.getHashSymbol(), @@ -236,9 +372,7 @@ private PlanNode split(AggregationNode node) return new AggregationNode( node.getId(), partial, - finalCalls, - node.getFunctions(), - ImmutableMap.of(), + finalAggregation, node.getGroupingSets(), FINAL, node.getHashSymbol(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java new file mode 100644 index 0000000000000..b2929434565d7 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.sql.ExpressionUtils; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.tree.DefaultTraversalVisitor; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class PlanNodeDecorrelator +{ + private final PlanNodeIdAllocator idAllocator; + private final Lookup lookup; + + public PlanNodeDecorrelator(PlanNodeIdAllocator idAllocator, Lookup lookup) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); + } + + public Optional decorrelateFilters(PlanNode node, List correlation) + { + PlanNodeSearcher filterNodeSearcher = searchFrom(node, lookup) + .where(FilterNode.class::isInstance) + .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, LimitNode.class)); + List filterNodes = filterNodeSearcher.findAll(); + + if (filterNodes.isEmpty()) { + return decorrelatedNode(ImmutableList.of(), node, correlation); + } + + if (filterNodes.size() > 1) { + return Optional.empty(); + } + + FilterNode filterNode = filterNodes.get(0); + Expression predicate = filterNode.getPredicate(); + + if (!isSupportedPredicate(predicate)) { + return Optional.empty(); + } + + if (!SymbolsExtractor.extractUnique(predicate).containsAll(correlation)) { + return Optional.empty(); + } + + Map> predicates = ExpressionUtils.extractConjuncts(predicate).stream() + .collect(Collectors.partitioningBy(isUsingPredicate(correlation))); + List correlatedPredicates = ImmutableList.copyOf(predicates.get(true)); + List uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false)); + + node = updateFilterNode(filterNodeSearcher, uncorrelatedPredicates); + + if (!correlatedPredicates.isEmpty()) { + // filterNodes condition has changed so Limit node no longer applies for EXISTS subquery + node = removeLimitNode(node); + } + + node = ensureJoinSymbolsAreReturned(node, correlatedPredicates); + + return decorrelatedNode(correlatedPredicates, node, correlation); + } + + private static boolean isSupportedPredicate(Expression predicate) + { + AtomicBoolean isSupported = new AtomicBoolean(true); + new DefaultTraversalVisitor() + { + @Override + protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) + { + if (node.getType() != LogicalBinaryExpression.Type.AND) { + context.set(false); + } + return null; + } + }.process(predicate, isSupported); + return isSupported.get(); + } + + private Predicate isUsingPredicate(List symbols) + { + return expression -> symbols.stream().anyMatch(SymbolsExtractor.extractUnique(expression)::contains); + } + + private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List newPredicates) + { + if (newPredicates.isEmpty()) { + return filterNodeSearcher.removeAll(); + } + FilterNode oldFilterNode = Iterables.getOnlyElement(filterNodeSearcher.findAll()); + FilterNode newFilterNode = new FilterNode( + idAllocator.getNextId(), + oldFilterNode.getSource(), + ExpressionUtils.combineConjuncts(newPredicates)); + return filterNodeSearcher.replaceAll(newFilterNode); + } + + private PlanNode removeLimitNode(PlanNode node) + { + node = searchFrom(node, lookup) + .where(LimitNode.class::isInstance) + .skipOnlyWhen(ProjectNode.class::isInstance) + .removeFirst(); + return node; + } + + private PlanNode ensureJoinSymbolsAreReturned(PlanNode scalarAggregationSource, List joinPredicate) + { + Set joinExpressionSymbols = SymbolsExtractor.extractUnique(joinPredicate); + ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter( + idAllocator, + joinExpressionSymbols); + return rewriteWith(extendProjectionRewriter, scalarAggregationSource); + } + + private Optional decorrelatedNode( + List correlatedPredicates, + PlanNode node, + List correlation) + { + if (SymbolsExtractor.extractUnique(node, lookup).stream().anyMatch(correlation::contains)) { + // node is still correlated ; / + return Optional.empty(); + } + return Optional.of(new DecorrelatedNode(correlatedPredicates, node)); + } + + public static class DecorrelatedNode + { + private final List correlatedPredicates; + private final PlanNode node; + + public DecorrelatedNode(List correlatedPredicates, PlanNode node) + { + requireNonNull(correlatedPredicates, "correlatedPredicates is null"); + this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates); + this.node = requireNonNull(node, "node is null"); + } + + Optional getCorrelatedPredicates() + { + if (correlatedPredicates.isEmpty()) { + return Optional.empty(); + } + return Optional.of(ExpressionUtils.and(correlatedPredicates)); + } + + public PlanNode getNode() + { + return node; + } + } + + private static class ExtendProjectionRewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + private final Set symbols; + + ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set symbols) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.symbols = requireNonNull(symbols, "symbols is null"); + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext context) + { + ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get()); + + List symbolsToAdd = symbols.stream() + .filter(rewrittenNode.getSource().getOutputSymbols()::contains) + .filter(symbol -> !rewrittenNode.getOutputSymbols().contains(symbol)) + .collect(toImmutableList()); + + Assignments assignments = Assignments.builder() + .putAll(rewrittenNode.getAssignments()) + .putIdentities(symbolsToAdd) + .build(); + + return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java index 8e14fd6072677..7a0be1835c6f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeSearcher.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.optimizations; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableList; @@ -20,26 +21,36 @@ import java.util.Optional; import java.util.function.Predicate; -import static com.facebook.presto.sql.planner.optimizations.Predicates.alwaysTrue; +import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; public class PlanNodeSearcher { + @Deprecated public static PlanNodeSearcher searchFrom(PlanNode node) { - return new PlanNodeSearcher(node); + return searchFrom(node, noLookup()); + } + + public static PlanNodeSearcher searchFrom(PlanNode node, Lookup lookup) + { + return new PlanNodeSearcher(node, lookup); } private final PlanNode node; + private final Lookup lookup; private Predicate where = alwaysTrue(); private Predicate skipOnly = alwaysTrue(); - public PlanNodeSearcher(PlanNode node) + public PlanNodeSearcher(PlanNode node, Lookup lookup) { this.node = requireNonNull(node, "node is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); } public PlanNodeSearcher where(Predicate where) @@ -61,6 +72,8 @@ public Optional findFirst() private Optional findFirstRecursive(PlanNode node) { + node = lookup.resolve(node); + if (where.test(node)) { return Optional.of((T) node); } @@ -75,6 +88,15 @@ private Optional findFirstRecursive(PlanNode node) return Optional.empty(); } + public Optional findSingle() + { + List all = findAll(); + if (all.size() == 1) { + return Optional.of(all.get(0)); + } + return Optional.empty(); + } + public List findAll() { ImmutableList.Builder nodes = ImmutableList.builder(); @@ -82,8 +104,24 @@ public List findAll() return nodes.build(); } + public T findOnlyElement() + { + return getOnlyElement(findAll()); + } + + public T findOnlyElement(T defaultValue) + { + List all = findAll(); + if (all.size() == 0) { + return defaultValue; + } + return getOnlyElement(all); + } + private void findAllRecursive(PlanNode node, ImmutableList.Builder nodes) { + node = lookup.resolve(node); + if (where.test(node)) { nodes.add((T) node); } @@ -101,6 +139,8 @@ public PlanNode removeAll() private PlanNode removeAllRecursive(PlanNode node) { + node = lookup.resolve(node); + if (where.test(node)) { checkArgument( node.getSources().size() == 1, @@ -123,6 +163,8 @@ public PlanNode removeFirst() private PlanNode removeFirstRecursive(PlanNode node) { + node = lookup.resolve(node); + if (where.test(node)) { checkArgument( node.getSources().size() == 1, @@ -151,6 +193,8 @@ public PlanNode replaceAll(PlanNode newPlanNode) private PlanNode replaceAllRecursive(PlanNode node, PlanNode nodeToReplace) { + node = lookup.resolve(node); + if (where.test(node)) { return nodeToReplace; } @@ -170,6 +214,8 @@ public PlanNode replaceFirst(PlanNode newPlanNode) private PlanNode replaceFirstRecursive(PlanNode node, PlanNode nodeToReplace) { + node = lookup.resolve(node); + if (where.test(node)) { return nodeToReplace; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index bff21db703ace..e726a6aace18f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -17,7 +17,6 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.EffectivePredicateExtractor; import com.facebook.presto.sql.planner.EqualityInference; @@ -28,6 +27,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; @@ -50,9 +50,9 @@ import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.SymbolReference; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -70,7 +70,6 @@ import java.util.stream.Collectors; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; -import static com.facebook.presto.sql.ExpressionUtils.expressionOrNullSymbols; import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; import static com.facebook.presto.sql.ExpressionUtils.stripNonDeterministicConjuncts; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; @@ -82,7 +81,6 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Predicates.equalTo; import static com.google.common.base.Predicates.in; import static com.google.common.base.Predicates.not; import static com.google.common.collect.Iterables.filter; @@ -196,7 +194,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex .map(Map.Entry::getKey) .collect(Collectors.toSet()); - Predicate deterministic = conjunct -> DependencyExtractor.extractUnique(conjunct).stream() + Predicate deterministic = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(deterministicSymbols::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); @@ -216,13 +214,13 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex @Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) { - checkState(!DependencyExtractor.extractUnique(context.get()).contains(node.getGroupIdSymbol()), "groupId symbol cannot be referenced in predicate"); + checkState(!SymbolsExtractor.extractUnique(context.get()).contains(node.getGroupIdSymbol()), "groupId symbol cannot be referenced in predicate"); Map commonGroupingSymbolMapping = node.getGroupingSetMappings().entrySet().stream() .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); - Predicate pushdownEligiblePredicate = conjunct -> DependencyExtractor.extractUnique(conjunct).stream() + Predicate pushdownEligiblePredicate = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(commonGroupingSymbolMapping.keySet()::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); @@ -242,7 +240,7 @@ public PlanNode visitGroupId(GroupIdNode node, RewriteContext contex @Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) { - checkState(!DependencyExtractor.extractUnique(context.get()).contains(node.getMarkerSymbol()), "predicate depends on marker symbol"); + checkState(!SymbolsExtractor.extractUnique(context.get()).contains(node.getMarkerSymbol()), "predicate depends on marker symbol"); return context.defaultRewrite(node, context.get()); } @@ -374,7 +372,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) { ComparisonExpression equality = (ComparisonExpression) conjunct; - boolean alignedComparison = Iterables.all(DependencyExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); + boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight(); Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft(); @@ -475,8 +473,8 @@ private static PlanNode createJoinNodeWithExpectedOutputs( private static OuterJoinPushDownResult processLimitedOuterJoin(Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection outerSymbols) { - checkArgument(Iterables.all(DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); - checkArgument(Iterables.all(DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); ImmutableList.Builder outerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder innerPushdownConjuncts = ImmutableList.builder(); @@ -596,8 +594,8 @@ private Expression getPostJoinPredicate() private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection leftSymbols) { - checkArgument(Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); - checkArgument(Iterables.all(DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); + checkArgument(Iterables.all(SymbolsExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); ImmutableList.Builder leftPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder rightPushDownConjuncts = ImmutableList.builder(); @@ -717,22 +715,16 @@ private static Expression extractJoinPredicate(JoinNode joinNode) { ImmutableList.Builder builder = ImmutableList.builder(); for (JoinNode.EquiJoinClause equiJoinClause : joinNode.getCriteria()) { - builder.add(equalsExpression(equiJoinClause.getLeft(), equiJoinClause.getRight())); + builder.add(equiJoinClause.toExpression()); } joinNode.getFilter().ifPresent(builder::add); return combineConjuncts(builder.build()); } - private static Expression equalsExpression(Symbol symbol1, Symbol symbol2) - { - return new ComparisonExpression(ComparisonExpressionType.EQUAL, - symbol1.toSymbolReference(), - symbol2.toSymbolReference()); - } - private Type extractType(Expression expression) { - return getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList() /* parameters have already been replaced */).get(expression); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList() /* parameters have already been replaced */); + return expressionTypes.get(NodeRef.of(expression)); } private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate) @@ -786,7 +778,7 @@ private boolean canConvertOuterToInner(List innerSymbolsForOuterJoin, Ex // Temporary implementation for joins because the SimplifyExpressions optimizers can not run properly on join clauses private Expression simplifyExpression(Expression expression) { - IdentityLinkedHashMap expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, @@ -794,7 +786,7 @@ private Expression simplifyExpression(Expression expression) expression, emptyList() /* parameters have already been replaced */); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); - return LiteralInterpreter.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(expression)); + return LiteralInterpreter.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } /** @@ -802,7 +794,7 @@ private Expression simplifyExpression(Expression expression) */ private Object nullInputEvaluator(final Collection nullSymbols, Expression expression) { - IdentityLinkedHashMap expressionTypes = getExpressionTypes( + Map, Type> expressionTypes = getExpressionTypes( session, metadata, sqlParser, @@ -820,8 +812,8 @@ private static Predicate joinEqualityExpression(final Collection symbols1 = DependencyExtractor.extractUnique(comparison.getLeft()); - Set symbols2 = DependencyExtractor.extractUnique(comparison.getRight()); + Set symbols1 = SymbolsExtractor.extractUnique(comparison.getLeft()); + Set symbols2 = SymbolsExtractor.extractUnique(comparison.getRight()); if (symbols1.isEmpty() || symbols2.isEmpty()) { return false; } @@ -838,30 +830,12 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext cont { Expression inheritedPredicate = context.get(); - Expression sourceEffectivePredicate = EffectivePredicateExtractor.extract(node.getSource(), symbolAllocator.getTypes()); - List sourceConjuncts = new ArrayList<>(); - List filteringSourceConjuncts = new ArrayList<>(); List postJoinConjuncts = new ArrayList<>(); // TODO: see if there are predicates that can be inferred from the semi join output - // Push inherited and source predicates to filtering source via a contrived join predicate (but needs to avoid touching NULL values in the filtering source) - Expression joinPredicate = equalsExpression(node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol()); - EqualityInference joinInference = createEqualityInference(inheritedPredicate, sourceEffectivePredicate, joinPredicate); - for (Expression conjunct : Iterables.concat(EqualityInference.nonInferrableConjuncts(inheritedPredicate), EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate))) { - Expression rewrittenConjunct = joinInference.rewriteExpression(conjunct, equalTo(node.getFilteringSourceJoinSymbol())); - if (rewrittenConjunct != null && DeterminismEvaluator.isDeterministic(rewrittenConjunct)) { - // Alter conjunct to include an OR filteringSourceJoinSymbol IS NULL disjunct - Expression rewrittenConjunctOrNull = expressionOrNullSymbols(Predicate.isEqual(node.getFilteringSourceJoinSymbol())).apply(rewrittenConjunct); - filteringSourceConjuncts.add(rewrittenConjunctOrNull); - } - } - EqualityInference.EqualityPartition joinInferenceEqualityPartition = joinInference.generateEqualitiesPartitionedBy(equalTo(node.getFilteringSourceJoinSymbol())); - - filteringSourceConjuncts.addAll(joinInferenceEqualityPartition.getScopeEqualities().stream() - .map(expressionOrNullSymbols(Predicate.isEqual(node.getFilteringSourceJoinSymbol()))) - .collect(Collectors.toList())); + PlanNode rewrittenFilteringSource = context.defaultRewrite(node.getFilteringSource(), BooleanLiteral.TRUE_LITERAL); // Push inheritedPredicates down to the source if they don't involve the semi join output EqualityInference inheritedInference = createEqualityInference(inheritedPredicate); @@ -883,7 +857,6 @@ public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext cont postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(sourceConjuncts)); - PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), combineConjuncts(filteringSourceConjuncts)); PlanNode output = node; if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) { @@ -938,8 +911,6 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext co @Override public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) { - Set predicateSymbols = DependencyExtractor.extractUnique(context.get()); + Set predicateSymbols = SymbolsExtractor.extractUnique(context.get()); checkState(!predicateSymbols.contains(node.getIdColumn()), "UniqueId in predicate is not yet supported"); return context.defaultRewrite(node, context.get()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java index f6453b4ebb2f7..7089a11370a0c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PreferredProperties.java @@ -68,10 +68,10 @@ public static PreferredProperties partitioned(Set columns) .build(); } - public static PreferredProperties partitionedWithNullsReplicated(Set columns) + public static PreferredProperties partitionedWithNullsAndAnyReplicated(Set columns) { return builder() - .global(Global.distributed(PartitioningProperties.partitioned(columns).withNullsReplicated(true))) + .global(Global.distributed(PartitioningProperties.partitioned(columns).withNullsAndAnyReplicated(true))) .build(); } @@ -89,10 +89,10 @@ public static PreferredProperties partitioned(Partitioning partitioning) .build(); } - public static PreferredProperties partitionedWithNullsReplicated(Partitioning partitioning) + public static PreferredProperties partitionedWithNullsAndAnyReplicated(Partitioning partitioning) { return builder() - .global(Global.distributed(PartitioningProperties.partitioned(partitioning).withNullsReplicated(true))) + .global(Global.distributed(PartitioningProperties.partitioned(partitioning).withNullsAndAnyReplicated(true))) .build(); } @@ -305,20 +305,20 @@ public static final class PartitioningProperties { private final Set partitioningColumns; private final Optional partitioning; // Specific partitioning requested - private final boolean nullsReplicated; + private final boolean nullsAndAnyReplicated; - private PartitioningProperties(Set partitioningColumns, Optional partitioning, boolean nullsReplicated) + private PartitioningProperties(Set partitioningColumns, Optional partitioning, boolean nullsAndAnyReplicated) { this.partitioningColumns = ImmutableSet.copyOf(requireNonNull(partitioningColumns, "partitioningColumns is null")); this.partitioning = requireNonNull(partitioning, "function is null"); - this.nullsReplicated = nullsReplicated; + this.nullsAndAnyReplicated = nullsAndAnyReplicated; checkArgument(!partitioning.isPresent() || partitioning.get().getColumns().equals(partitioningColumns), "Partitioning input must match partitioningColumns"); } - public PartitioningProperties withNullsReplicated(boolean nullsReplicated) + public PartitioningProperties withNullsAndAnyReplicated(boolean nullsAndAnyReplicated) { - return new PartitioningProperties(partitioningColumns, partitioning, nullsReplicated); + return new PartitioningProperties(partitioningColumns, partitioning, nullsAndAnyReplicated); } public static PartitioningProperties partitioned(Partitioning partitioning) @@ -346,9 +346,9 @@ public Optional getPartitioning() return partitioning; } - public boolean isNullsReplicated() + public boolean isNullsAndAnyReplicated() { - return nullsReplicated; + return nullsAndAnyReplicated; } public PartitioningProperties mergeWithParent(PartitioningProperties parent) @@ -358,8 +358,8 @@ public PartitioningProperties mergeWithParent(PartitioningProperties parent) return this; } - // Partitioning with different null replication cannot be compared - if (nullsReplicated != parent.nullsReplicated) { + // Partitioning with different replication cannot be compared + if (nullsAndAnyReplicated != parent.nullsAndAnyReplicated) { return this; } @@ -371,7 +371,7 @@ public PartitioningProperties mergeWithParent(PartitioningProperties parent) // Otherwise partition on any common columns if available Set common = Sets.intersection(partitioningColumns, parent.partitioningColumns); - return common.isEmpty() ? this : partitioned(common).withNullsReplicated(nullsReplicated); + return common.isEmpty() ? this : partitioned(common).withNullsAndAnyReplicated(nullsAndAnyReplicated); } public Optional translate(Function> translator) @@ -388,7 +388,7 @@ public Optional translate(Function newPartitioning = partitioning.get().translate(translator, symbol -> Optional.empty()); @@ -396,13 +396,13 @@ public Optional translate(Function, ActualProperties> + extends PlanVisitor> { private final Metadata metadata; private final Session session; @@ -186,7 +187,13 @@ public ActualProperties visitAssignUniqueId(AssignUniqueId node, List inputProperties) { - return inputProperties.get(0); // apply node input (outer query) + throw new IllegalArgumentException("Unexpected node: " + node.getClass().getName()); + } + + @Override + public ActualProperties visitLateralJoin(LateralJoinNode node, List inputProperties) + { + throw new IllegalArgumentException("Unexpected node: " + node.getClass().getName()); } @Override @@ -435,7 +442,7 @@ public static Map exchangeInputToOutput(ExchangeNode node, int s @Override public ActualProperties visitExchange(ExchangeNode node, List inputProperties) { - checkArgument(node.getScope() != REMOTE || inputProperties.stream().noneMatch(ActualProperties::isNullsReplicated), "Null replicated inputs should not be remotely exchanged"); + checkArgument(node.getScope() != REMOTE || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated), "Null-and-any replicated inputs should not be remotely exchanged"); Set> entries = null; for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) { @@ -472,7 +479,7 @@ public ActualProperties visitExchange(ExchangeNode node, List .global(partitionedOn( node.getPartitioningScheme().getPartitioning(), Optional.of(node.getPartitioningScheme().getPartitioning())) - .withReplicatedNulls(node.getPartitioningScheme().isReplicateNulls())) + .withReplicatedNulls(node.getPartitioningScheme().isReplicateNullsAndAny())) .constants(constants) .build(); case REPLICATE: @@ -519,8 +526,8 @@ public ActualProperties visitProject(ProjectNode node, List in for (Map.Entry assignment : node.getAssignments().entrySet()) { Expression expression = assignment.getValue(); - IdentityLinkedHashMap expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList() /* parameters already replaced */); - Type type = requireNonNull(expressionTypes.get(expression)); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList() /* parameters already replaced */); + Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); // TODO: // We want to use a symbol resolver that looks up in the constants from the input subplan diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index dbbbf35915823..33c735e02effb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -14,15 +14,15 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; @@ -37,6 +37,7 @@ import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -77,6 +78,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -148,7 +150,7 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext> con node.getPartitioningScheme().getPartitioning(), newOutputSymbols, node.getPartitioningScheme().getHashColumn(), - node.getPartitioningScheme().isReplicateNulls(), + node.getPartitioningScheme().isReplicateNullsAndAny(), node.getPartitioningScheme().getBucketToPartition()); ImmutableList.Builder rewrittenSources = ImmutableList.builder(); @@ -176,7 +178,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext> context) Set expectedFilterInputs = new HashSet<>(); if (node.getFilter().isPresent()) { expectedFilterInputs = ImmutableSet.builder() - .addAll(DependencyExtractor.extractUnique(node.getFilter().get())) + .addAll(SymbolsExtractor.extractUnique(node.getFilter().get())) .addAll(context.get()) .build(); } @@ -306,22 +308,18 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext functions = ImmutableMap.builder(); - ImmutableMap.Builder functionCalls = ImmutableMap.builder(); - ImmutableMap.Builder masks = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); if (context.get().contains(symbol)) { - FunctionCall call = entry.getValue(); - expectedInputs.addAll(DependencyExtractor.extractUnique(call)); - if (node.getMasks().containsKey(symbol)) { - expectedInputs.add(node.getMasks().get(symbol)); - masks.put(symbol, node.getMasks().get(symbol)); + Aggregation aggregation = entry.getValue(); + FunctionCall call = aggregation.getCall(); + expectedInputs.addAll(SymbolsExtractor.extractUnique(call)); + if (aggregation.getMask().isPresent()) { + expectedInputs.add(aggregation.getMask().get()); } - - functionCalls.put(symbol, call); - functions.put(symbol, node.getFunctions().get(symbol)); + aggregations.put(symbol, new Aggregation(call, aggregation.getSignature(), aggregation.getMask())); } } @@ -329,9 +327,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext> context if (context.get().contains(symbol)) { FunctionCall call = function.getFunctionCall(); - expectedInputs.addAll(DependencyExtractor.extractUnique(call)); + expectedInputs.addAll(SymbolsExtractor.extractUnique(call)); functionsBuilder.put(symbol, entry.getValue()); } @@ -417,7 +413,7 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext> c public PlanNode visitFilter(FilterNode node, RewriteContext> context) { Set expectedInputs = ImmutableSet.builder() - .addAll(DependencyExtractor.extractUnique(node.getPredicate())) + .addAll(SymbolsExtractor.extractUnique(node.getPredicate())) .addAll(context.get()) .build(); @@ -508,7 +504,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext> conte Expression expression = node.getAssignments().get(output); if (context.get().contains(output)) { - expectedInputs.addAll(DependencyExtractor.extractUnique(expression)); + expectedInputs.addAll(SymbolsExtractor.extractUnique(expression)); builder.put(output, expression); } } @@ -540,13 +536,13 @@ public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext expectedInputs; if (node.getHashSymbol().isPresent()) { - expectedInputs = ImmutableSet.copyOf(concat(node.getOutputSymbols(), ImmutableList.of(node.getHashSymbol().get()))); + expectedInputs = ImmutableSet.copyOf(concat(node.getDistinctSymbols(), ImmutableList.of(node.getHashSymbol().get()))); } else { - expectedInputs = ImmutableSet.copyOf(node.getOutputSymbols()); + expectedInputs = ImmutableSet.copyOf(node.getDistinctSymbols()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs); - return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), node.getHashSymbol()); + return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), node.getDistinctSymbols(), node.getHashSymbol()); } @Override @@ -558,7 +554,7 @@ public PlanNode visitTopN(TopNNode node, RewriteContext> context) PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - return new TopNNode(node.getId(), source, node.getCount(), node.getOrderBy(), node.getOrderings(), node.isPartial()); + return new TopNNode(node.getId(), source, node.getCount(), node.getOrderBy(), node.getOrderings(), node.getStep()); } @Override @@ -739,7 +735,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) Symbol output = entry.getKey(); Expression expression = entry.getValue(); if (context.get().contains(output)) { - subqueryAssignmentsSymbolsBuilder.addAll(DependencyExtractor.extractUnique(expression)); + subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression)); subqueryAssignments.put(output, expression); } } @@ -748,7 +744,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols); // prune not used correlation symbols - Set subquerySymbols = DependencyExtractor.extractUnique(subquery); + Set subquerySymbols = SymbolsExtractor.extractUnique(subquery); List newCorrelation = node.getCorrelation().stream() .filter(subquerySymbols::contains) .collect(toImmutableList()); @@ -770,5 +766,29 @@ public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext> context) + { + PlanNode subquery = context.rewrite(node.getSubquery(), context.get()); + + // remove unused lateral nodes + if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty() && isScalar(subquery)) { + return context.rewrite(node.getInput(), context.get()); + } + + // prune not used correlation symbols + Set subquerySymbols = SymbolsExtractor.extractUnique(subquery); + List newCorrelation = node.getCorrelation().stream() + .filter(subquerySymbols::contains) + .collect(toImmutableList()); + + Set inputContext = ImmutableSet.builder() + .addAll(context.get()) + .addAll(newCorrelation) + .build(); + PlanNode input = context.rewrite(node.getInput(), inputContext); + return new LateralJoinNode(node.getId(), input, subquery, newCorrelation, node.getType()); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarInputApplyNodes.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java similarity index 68% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarInputApplyNodes.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java index 359e61ba352c9..07e331302471f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarInputApplyNodes.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RemoveUnreferencedScalarLateralNodes.java @@ -19,7 +19,7 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; @@ -29,9 +29,10 @@ import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; /** - * Remove resolved ApplyNodes with unreferenced scalar input, e.g: "SELECT (SELECT 1)". + * Remove LateralJoinNodes with unreferenced scalar input, e.g: "SELECT (SELECT 1)". */ -public class RemoveUnreferencedScalarInputApplyNodes +@Deprecated +public class RemoveUnreferencedScalarLateralNodes implements PlanOptimizer { @Override @@ -44,13 +45,25 @@ private static class Rewriter extends SimplePlanRewriter { @Override - public PlanNode visitApply(ApplyNode node, RewriteContext context) + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) { - if (node.getInput().getOutputSymbols().isEmpty() && isScalar(node.getInput()) && node.isResolvedScalarSubquery()) { - return context.rewrite(node.getSubquery()); + PlanNode input = node.getInput(); + PlanNode subquery = node.getSubquery(); + + if (isUnreferencedScalar(input)) { + return context.rewrite(subquery); + } + + if (isUnreferencedScalar(subquery)) { + return context.rewrite(input); } return context.defaultRewrite(node); } + + private boolean isUnreferencedScalar(PlanNode input) + { + return input.getOutputSymbols().isEmpty() && isScalar(input); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java new file mode 100644 index 0000000000000..93ebbf295aa79 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.BooleanType; +import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator.DecorrelatedNode; +import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +// TODO: move this class to TransformCorrelatedScalarAggregationToJoin when old optimizer is gone +public class ScalarAggregationToJoinRewriter +{ + private static final QualifiedName COUNT = QualifiedName.of("count"); + + private final FunctionRegistry functionRegistry; + private final SymbolAllocator symbolAllocator; + private final PlanNodeIdAllocator idAllocator; + private final Lookup lookup; + private final PlanNodeDecorrelator planNodeDecorrelator; + + public ScalarAggregationToJoinRewriter(FunctionRegistry functionRegistry, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) + { + this.functionRegistry = requireNonNull(functionRegistry, "metadata is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); + this.planNodeDecorrelator = new PlanNodeDecorrelator(idAllocator, lookup); + } + + public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregation) + { + List correlation = lateralJoinNode.getCorrelation(); + Optional source = planNodeDecorrelator.decorrelateFilters(lookup.resolve(aggregation.getSource()), correlation); + if (!source.isPresent()) { + return lateralJoinNode; + } + + Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); + Assignments scalarAggregationSourceAssignments = Assignments.builder() + .putAll(Assignments.identity(source.get().getNode().getOutputSymbols())) + .put(nonNull, TRUE_LITERAL) + .build(); + ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode( + idAllocator.getNextId(), + source.get().getNode(), + scalarAggregationSourceAssignments); + + return rewriteScalarAggregation( + lateralJoinNode, + aggregation, + scalarAggregationSourceWithNonNullableSymbol, + source.get().getCorrelatedPredicates(), + nonNull); + } + + private PlanNode rewriteScalarAggregation( + LateralJoinNode lateralJoinNode, + AggregationNode scalarAggregation, + PlanNode scalarAggregationSource, + Optional joinExpression, + Symbol nonNull) + { + AssignUniqueId inputWithUniqueColumns = new AssignUniqueId( + idAllocator.getNextId(), + lateralJoinNode.getInput(), + symbolAllocator.newSymbol("unique", BigintType.BIGINT)); + + JoinNode leftOuterJoin = new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.LEFT, + inputWithUniqueColumns, + scalarAggregationSource, + ImmutableList.of(), + ImmutableList.builder() + .addAll(inputWithUniqueColumns.getOutputSymbols()) + .addAll(scalarAggregationSource.getOutputSymbols()) + .build(), + joinExpression, + Optional.empty(), + Optional.empty(), + Optional.empty()); + + Optional aggregationNode = createAggregationNode( + scalarAggregation, + leftOuterJoin, + nonNull); + + if (!aggregationNode.isPresent()) { + return lateralJoinNode; + } + + Optional subqueryProjection = searchFrom(lateralJoinNode.getSubquery(), lookup) + .where(ProjectNode.class::isInstance) + .skipOnlyWhen(EnforceSingleRowNode.class::isInstance) + .findFirst(); + + List aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregationNode.get()); + + if (subqueryProjection.isPresent()) { + Assignments assignments = Assignments.builder() + .putAll(Assignments.identity(aggregationOutputSymbols)) + .putAll(subqueryProjection.get().getAssignments()) + .build(); + + return new ProjectNode( + idAllocator.getNextId(), + aggregationNode.get(), + assignments); + } + else { + Assignments assignments = Assignments.builder() + .putAll(Assignments.identity(aggregationOutputSymbols)) + .build(); + + return new ProjectNode( + idAllocator.getNextId(), + aggregationNode.get(), + assignments); + } + } + + private static List getTruncatedAggregationSymbols(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) + { + Set applySymbols = new HashSet<>(lateralJoinNode.getOutputSymbols()); + return aggregationNode.getOutputSymbols().stream() + .filter(symbol -> applySymbols.contains(symbol)) + .collect(toImmutableList()); + } + + private Optional createAggregationNode( + AggregationNode scalarAggregation, + JoinNode leftOuterJoin, + Symbol nonNullableAggregationSourceSymbol) + { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : scalarAggregation.getAggregations().entrySet()) { + FunctionCall call = entry.getValue().getCall(); + Symbol symbol = entry.getKey(); + if (call.getName().equals(COUNT)) { + List scalarAggregationSourceTypeSignatures = ImmutableList.of( + symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature()); + aggregations.put(symbol, new Aggregation( + new FunctionCall( + COUNT, + ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference())), + functionRegistry.resolveFunction( + COUNT, + fromTypeSignatures(scalarAggregationSourceTypeSignatures)), + entry.getValue().getMask())); + } + else { + aggregations.put(symbol, entry.getValue()); + } + } + + List groupBySymbols = leftOuterJoin.getLeft().getOutputSymbols(); + return Optional.of(new AggregationNode( + idAllocator.getNextId(), + leftOuterJoin, + aggregations.build(), + ImmutableList.of(groupBySymbols), + scalarAggregation.getStep(), + scalarAggregation.getHashSymbol(), + Optional.empty())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java index cb8d0397dc91c..5988b8a2b131e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarQueryUtil.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.sql.planner.optimizations; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -23,26 +25,46 @@ import com.facebook.presto.sql.planner.plan.ValuesNode; import com.google.common.collect.ImmutableList; +import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; public final class ScalarQueryUtil { private ScalarQueryUtil() {} + public static boolean isScalar(PlanNode node, Lookup lookup) + { + return node.accept(new IsScalarPlanVisitor(lookup), null); + } + public static boolean isScalar(PlanNode node) { - return node.accept(new IsScalarPlanVisitor(), null); + return isScalar(node, noLookup()); } private static final class IsScalarPlanVisitor - extends PlanVisitor + extends PlanVisitor { + private final Lookup lookup; + + public IsScalarPlanVisitor(Lookup lookup) + { + this.lookup = requireNonNull(lookup, "lookup is null"); + } + @Override protected Boolean visitPlan(PlanNode node, Void context) { return false; } + @Override + public Boolean visitGroupReference(GroupReference node, Void context) + { + return lookup.resolve(node).accept(this, context); + } + @Override public Boolean visitEnforceSingleRow(EnforceSingleRowNode node, Void context) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java index 2e969d827afc7..a0cd77ac0fb10 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java @@ -131,8 +131,6 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext c node.getId(), rewrittenNode, node.getAggregations(), - node.getFunctions(), - node.getMasks(), node.getGroupingSets(), node.getStep(), node.getHashSymbol(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java index a289ea3c171fb..de5b5d8cd58b1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyExpressions.java @@ -35,10 +35,10 @@ import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.SymbolReference; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; @@ -150,9 +150,9 @@ private Expression simplifyExpression(Expression expression) } expression = ExpressionTreeRewriter.rewriteWith(new PushDownNegationsExpressionRewriter(), expression); expression = ExpressionTreeRewriter.rewriteWith(new ExtractCommonPredicatesExpressionRewriter(), expression, NodeContext.ROOT_NODE); - IdentityLinkedHashMap expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList() /* parameters already replaced */); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList() /* parameters already replaced */); ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); - return LiteralInterpreter.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(expression)); + return LiteralInterpreter.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 98f0a9ce558dd..e85b10db86db0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -35,6 +35,7 @@ import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -135,7 +136,7 @@ public static StreamProperties deriveProperties(PlanNode node, List, StreamProperties> + extends PlanVisitor> { private final Metadata metadata; private final Session session; @@ -459,6 +460,10 @@ public StreamProperties visitTopNRowNumber(TopNRowNumberNode node, List inputProperties) { + // Partial TopN doesn't guarantee that stream is ordered + if (node.getStep().equals(TopNNode.Step.PARTIAL)) { + return Iterables.getOnlyElement(inputProperties); + } return StreamProperties.ordered(); } @@ -489,7 +494,13 @@ public StreamProperties visitSemiJoin(SemiJoinNode node, List @Override public StreamProperties visitApply(ApplyNode node, List inputProperties) { - return inputProperties.get(0); + throw new IllegalStateException("Unexpected node: " + node.getClass()); + } + + @Override + public StreamProperties visitLateralJoin(LateralJoinNode node, List inputProperties) + { + throw new IllegalStateException("Unexpected node: " + node.getClass()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 4d050c0b34626..80785b702f70c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -13,12 +13,14 @@ */ package com.facebook.presto.sql.planner.optimizations; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -78,18 +80,12 @@ public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeIdAllo private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId) { - ImmutableMap.Builder functionInfos = ImmutableMap.builder(); - ImmutableMap.Builder functionCalls = ImmutableMap.builder(); - ImmutableMap.Builder masks = ImmutableMap.builder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Map.Entry entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); - Symbol canonical = map(symbol); - FunctionCall canonicalCall = (FunctionCall) map(entry.getValue()); - functionCalls.put(canonical, canonicalCall); - functionInfos.put(canonical, node.getFunctions().get(symbol)); - } - for (Map.Entry entry : node.getMasks().entrySet()) { - masks.put(map(entry.getKey()), map(entry.getValue())); + Aggregation aggregation = entry.getValue(); + + aggregations.put(map(symbol), new Aggregation((FunctionCall) map(aggregation.getCall()), aggregation.getSignature(), aggregation.getMask().map(this::map))); } List> groupingSets = node.getGroupingSets().stream() @@ -99,15 +95,36 @@ private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId ne return new AggregationNode( newNodeId, source, - functionCalls.build(), - functionInfos.build(), - masks.build(), + aggregations.build(), groupingSets, node.getStep(), node.getHashSymbol().map(this::map), node.getGroupIdSymbol().map(this::map)); } + public TopNNode map(TopNNode node, PlanNode source, PlanNodeId newNodeId) + { + ImmutableList.Builder symbols = ImmutableList.builder(); + ImmutableMap.Builder orderings = ImmutableMap.builder(); + Set seenCanonicals = new HashSet<>(node.getOrderBy().size()); + for (Symbol symbol : node.getOrderBy()) { + Symbol canonical = map(symbol); + if (seenCanonicals.add(canonical)) { + seenCanonicals.add(canonical); + symbols.add(canonical); + orderings.put(canonical, node.getOrderings().get(symbol)); + } + } + + return new TopNNode( + newNodeId, + source, + node.getCount(), + symbols.build(), + orderings.build(), + node.getStep()); + } + private List mapAndDistinct(List outputs) { Set added = new HashSet<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedNoAggregationSubqueryToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedNoAggregationSubqueryToJoin.java new file mode 100644 index 0000000000000..154fc6ab51f94 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedNoAggregationSubqueryToJoin.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator.DecorrelatedNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; +import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static java.util.Objects.requireNonNull; + +/** + * This optimizer can rewrite correlated no aggregation subquery to inner join in a way described here: + * From: + *

+ * - Lateral (with correlation list: [B])
+ *   - (input) plan which produces symbols: [A, B]
+ *   - (subquery)
+ *     - Filter(B = C AND D < 5)
+ *       - plan which produces symbols: [C, D]
+ * 
+ * to: + *
+ *   - Join(INNER, B = C)
+ *       - (input) plan which produces symbols: [A, B]
+ *       - Filter(D < 5)
+ *          - plan which produces symbols: [C, D]
+ * 
+ *

+ * Note only conjunction predicates in FilterNode are supported + */ +public class TransformCorrelatedNoAggregationSubqueryToJoin + implements PlanOptimizer +{ + @Override + public PlanNode optimize( + PlanNode plan, + Session session, + Map types, + SymbolAllocator symbolAllocator, + PlanNodeIdAllocator idAllocator) + { + return rewriteWith(new Rewriter(idAllocator), plan, null); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + + public Rewriter(PlanNodeIdAllocator idAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + } + + @Override + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) + { + LateralJoinNode rewrittenNode = (LateralJoinNode) context.defaultRewrite(node, context.get()); + if (!rewrittenNode.getCorrelation().isEmpty()) { + return rewriteNoAggregationSubquery(rewrittenNode); + } + return rewrittenNode; + } + + private PlanNode rewriteNoAggregationSubquery(LateralJoinNode lateral) + { + List correlation = lateral.getCorrelation(); + PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(idAllocator, noLookup()); + Optional source = decorrelator.decorrelateFilters(lateral.getSubquery(), correlation); + if (!source.isPresent()) { + return lateral; + } + + return new JoinNode( + idAllocator.getNextId(), + JoinNode.Type.INNER, + lateral.getInput(), + source.get().getNode(), + ImmutableList.of(), + lateral.getOutputSymbols(), + source.get().getCorrelatedPredicates(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java index 871fb577077a1..71b91696d1a0c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedScalarAggregationToJoin.java @@ -15,50 +15,24 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionRegistry; -import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; -import com.facebook.presto.spi.type.BigintType; -import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.sql.ExpressionUtils; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.ApplyNode; -import com.facebook.presto.sql.planner.plan.AssignUniqueId; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; -import com.facebook.presto.sql.planner.plan.FilterNode; -import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.tree.DefaultTraversalVisitor; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.LogicalBinaryExpression; -import com.facebook.presto.sql.tree.QualifiedName; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Predicate; -import java.util.stream.Collectors; -import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; +import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; -import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; import static java.util.Objects.requireNonNull; /** @@ -70,7 +44,7 @@ *

* From: *

- * - Apply (with correlation list: [C])
+ * - LateralJoin (with correlation list: [C])
  *   - (input) plan which produces symbols: [A, B, C]
  *   - (subquery) Aggregation(GROUP BY (); functions: [sum(F), count(), ...]
  *     - Filter(D = C AND E > 5)
@@ -89,14 +63,15 @@
  * 

* Note only conjunction predicates in FilterNode are supported */ +@Deprecated public class TransformCorrelatedScalarAggregationToJoin implements PlanOptimizer { - private final Metadata metadata; + private final FunctionRegistry functionRegistry; - public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) + public TransformCorrelatedScalarAggregationToJoin(FunctionRegistry functionRegistry) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null"); } @Override @@ -107,7 +82,7 @@ public PlanNode optimize( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { - return rewriteWith(new Rewriter(idAllocator, symbolAllocator, metadata), plan, null); + return rewriteWith(new Rewriter(idAllocator, symbolAllocator, functionRegistry), plan, null); } private static class Rewriter @@ -115,303 +90,30 @@ private static class Rewriter { private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; - private final Metadata metadata; + private final FunctionRegistry functionRegistry; - public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) + public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, FunctionRegistry functionRegistry) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); - this.metadata = requireNonNull(metadata, "metadata is null"); + this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null"); } @Override - public PlanNode visitApply(ApplyNode node, RewriteContext context) + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) { - ApplyNode rewrittenNode = (ApplyNode) context.defaultRewrite(node, context.get()); - if (!rewrittenNode.getCorrelation().isEmpty() && rewrittenNode.isResolvedScalarSubquery()) { + LateralJoinNode rewrittenNode = (LateralJoinNode) context.defaultRewrite(node, context.get()); + if (!rewrittenNode.getCorrelation().isEmpty()) { Optional aggregation = searchFrom(rewrittenNode.getSubquery()) .where(AggregationNode.class::isInstance) .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) .findFirst(); if (aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty()) { - return rewriteScalarAggregation(rewrittenNode, aggregation.get()); + ScalarAggregationToJoinRewriter scalarAggregationToJoinRewriter = new ScalarAggregationToJoinRewriter(functionRegistry, symbolAllocator, idAllocator, noLookup()); + return scalarAggregationToJoinRewriter.rewriteScalarAggregation(rewrittenNode, aggregation.get()); } } return rewrittenNode; } - - private PlanNode rewriteScalarAggregation(ApplyNode apply, AggregationNode aggregation) - { - List correlation = apply.getCorrelation(); - Optional source = decorrelateFilters(aggregation.getSource(), correlation); - if (!source.isPresent()) { - return apply; - } - - Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); - Assignments scalarAggregationSourceAssignments = Assignments.builder() - .putAll(Assignments.identity(source.get().getNode().getOutputSymbols())) - .put(nonNull, TRUE_LITERAL) - .build(); - ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode( - idAllocator.getNextId(), - source.get().getNode(), - scalarAggregationSourceAssignments); - - return rewriteScalarAggregation( - apply, - aggregation, - scalarAggregationSourceWithNonNullableSymbol, - source.get().getCorrelatedPredicates(), - nonNull); - } - - private PlanNode rewriteScalarAggregation( - ApplyNode applyNode, - AggregationNode scalarAggregation, - PlanNode scalarAggregationSource, - Optional joinExpression, - Symbol nonNull) - { - AssignUniqueId inputWithUniqueColumns = new AssignUniqueId( - idAllocator.getNextId(), - applyNode.getInput(), - symbolAllocator.newSymbol("unique", BigintType.BIGINT)); - - JoinNode leftOuterJoin = new JoinNode( - idAllocator.getNextId(), - JoinNode.Type.LEFT, - inputWithUniqueColumns, - scalarAggregationSource, - ImmutableList.of(), - ImmutableList.builder() - .addAll(inputWithUniqueColumns.getOutputSymbols()) - .addAll(scalarAggregationSource.getOutputSymbols()) - .build(), - joinExpression, - Optional.empty(), - Optional.empty(), - Optional.empty()); - - Optional aggregationNode = createAggregationNode( - scalarAggregation, - leftOuterJoin, - nonNull); - - if (!aggregationNode.isPresent()) { - return applyNode; - } - - Optional subqueryProjection = searchFrom(applyNode.getSubquery()) - .where(ProjectNode.class::isInstance) - .skipOnlyWhen(EnforceSingleRowNode.class::isInstance) - .findFirst(); - - if (subqueryProjection.isPresent()) { - Assignments assignments = Assignments.builder() - .putAll(Assignments.identity(aggregationNode.get().getOutputSymbols())) - .putAll(subqueryProjection.get().getAssignments()) - .build(); - - return new ProjectNode( - idAllocator.getNextId(), - aggregationNode.get(), - assignments); - } - else { - return aggregationNode.get(); - } - } - - private Optional createAggregationNode( - AggregationNode scalarAggregation, - JoinNode leftOuterJoin, - Symbol nonNullableAggregationSourceSymbol) - { - ImmutableMap.Builder aggregations = ImmutableMap.builder(); - ImmutableMap.Builder functions = ImmutableMap.builder(); - FunctionRegistry functionRegistry = metadata.getFunctionRegistry(); - for (Map.Entry entry : scalarAggregation.getAggregations().entrySet()) { - FunctionCall call = entry.getValue(); - QualifiedName count = QualifiedName.of("count"); - Symbol symbol = entry.getKey(); - if (call.getName().equals(count)) { - aggregations.put(symbol, new FunctionCall( - count, - ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference()))); - List scalarAggregationSourceTypeSignatures = ImmutableList.of( - symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature()); - functions.put(symbol, functionRegistry.resolveFunction( - count, - fromTypeSignatures(scalarAggregationSourceTypeSignatures))); - } - else { - aggregations.put(symbol, entry.getValue()); - functions.put(symbol, scalarAggregation.getFunctions().get(symbol)); - } - } - - List groupBySymbols = leftOuterJoin.getLeft().getOutputSymbols(); - return Optional.of(new AggregationNode( - idAllocator.getNextId(), - leftOuterJoin, - aggregations.build(), - functions.build(), - scalarAggregation.getMasks(), - ImmutableList.of(groupBySymbols), - scalarAggregation.getStep(), - scalarAggregation.getHashSymbol(), - Optional.empty())); - } - - private Optional decorrelateFilters(PlanNode node, List correlation) - { - PlanNodeSearcher filterNodeSearcher = searchFrom(node) - .where(FilterNode.class::isInstance) - .skipOnlyWhen(isInstanceOfAny(ProjectNode.class)); - List filterNodes = filterNodeSearcher.findAll(); - - if (filterNodes.isEmpty()) { - return decorrelatedNode(ImmutableList.of(), node, correlation); - } - - if (filterNodes.size() > 1) { - return Optional.empty(); - } - - FilterNode filterNode = filterNodes.get(0); - Expression predicate = filterNode.getPredicate(); - - if (!isSupportedPredicate(predicate)) { - return Optional.empty(); - } - - if (!DependencyExtractor.extractUnique(predicate).containsAll(correlation)) { - return Optional.empty(); - } - - Map> predicates = ExpressionUtils.extractConjuncts(predicate).stream() - .collect(Collectors.partitioningBy(isUsingPredicate(correlation))); - List correlatedPredicates = ImmutableList.copyOf(predicates.get(true)); - List uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false)); - - node = updateFilterNode(filterNodeSearcher, uncorrelatedPredicates); - node = ensureJoinSymbolsAreReturned(node, correlatedPredicates); - - return decorrelatedNode(correlatedPredicates, node, correlation); - } - - private static Optional decorrelatedNode( - List correlatedPredicates, - PlanNode node, - List correlation) - { - if (DependencyExtractor.extractUnique(node).stream().anyMatch(correlation::contains)) { - // node is still correlated ; / - return Optional.empty(); - } - return Optional.of(new DecorrelatedNode(correlatedPredicates, node)); - } - - private static Predicate isUsingPredicate(List symbols) - { - return expression -> symbols.stream().anyMatch(DependencyExtractor.extractUnique(expression)::contains); - } - - private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List newPredicates) - { - if (newPredicates.isEmpty()) { - return filterNodeSearcher.removeAll(); - } - FilterNode oldFilterNode = Iterables.getOnlyElement(filterNodeSearcher.findAll()); - FilterNode newFilterNode = new FilterNode( - idAllocator.getNextId(), - oldFilterNode.getSource(), - ExpressionUtils.combineConjuncts(newPredicates)); - return filterNodeSearcher.replaceAll(newFilterNode); - } - - private PlanNode ensureJoinSymbolsAreReturned(PlanNode scalarAggregationSource, List joinPredicate) - { - Set joinExpressionSymbols = DependencyExtractor.extractUnique(joinPredicate); - ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter( - idAllocator, - joinExpressionSymbols); - return rewriteWith(extendProjectionRewriter, scalarAggregationSource); - } - - private static boolean isSupportedPredicate(Expression predicate) - { - AtomicBoolean isSupported = new AtomicBoolean(true); - new DefaultTraversalVisitor() - { - @Override - protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) - { - if (node.getType() != LogicalBinaryExpression.Type.AND) { - context.set(false); - } - return null; - } - }.process(predicate, isSupported); - return isSupported.get(); - } - } - - private static class DecorrelatedNode - { - private final List correlatedPredicates; - private final PlanNode node; - - public DecorrelatedNode(List correlatedPredicates, PlanNode node) - { - requireNonNull(correlatedPredicates, "correlatedPredicates is null"); - this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates); - this.node = requireNonNull(node, "node is null"); - } - - Optional getCorrelatedPredicates() - { - if (correlatedPredicates.isEmpty()) { - return Optional.empty(); - } - return Optional.of(ExpressionUtils.and(correlatedPredicates)); - } - - public PlanNode getNode() - { - return node; - } - } - - private static class ExtendProjectionRewriter - extends SimplePlanRewriter - { - private final PlanNodeIdAllocator idAllocator; - private final Set symbols; - - ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set symbols) - { - this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); - this.symbols = requireNonNull(symbols, "symbols is null"); - } - - @Override - public PlanNode visitProject(ProjectNode node, RewriteContext context) - { - ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get()); - - List symbolsToAdd = symbols.stream() - .filter(rewrittenNode.getSource().getOutputSymbols()::contains) - .filter(symbol -> !rewrittenNode.getOutputSymbols().contains(symbol)) - .collect(toImmutableList()); - - Assignments assignments = Assignments.builder() - .putAll(rewrittenNode.getAssignments()) - .putAll(Assignments.identity(symbolsToAdd)) - .build(); - - return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); - } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedSingleRowSubqueryToProject.java new file mode 100644 index 0000000000000..8ddaaa6edb9a2 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformCorrelatedSingleRowSubqueryToProject.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.planner.plan.ValuesNode; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static java.util.Objects.requireNonNull; + +/** + * This optimizer can rewrite correlated single row subquery to projection in a way described here: + * From: + *

+ * - Lateral(with correlation list: [A, C])
+ *   - (input) plan which produces symbols: [A, B, C]
+ *   - (subquery)
+ *     - Project (A + C)
+ *       - single row VALUES()
+ * 
+ * to: + *
+ *   - Project(A, B, C, A + C)
+ *       - (input) plan which produces symbols: [A, B, C]
+ * 
+ */ +public class TransformCorrelatedSingleRowSubqueryToProject + implements PlanOptimizer +{ + @Override + public PlanNode optimize( + PlanNode plan, + Session session, + Map types, + SymbolAllocator symbolAllocator, + PlanNodeIdAllocator idAllocator) + { + return rewriteWith(new Rewriter(idAllocator), plan, null); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + + public Rewriter(PlanNodeIdAllocator idAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + } + + @Override + public PlanNode visitLateralJoin(LateralJoinNode lateral, RewriteContext context) + { + LateralJoinNode rewrittenLateral = (LateralJoinNode) context.defaultRewrite(lateral, context.get()); + if (rewrittenLateral.getCorrelation().isEmpty()) { + return rewrittenLateral; + } + + Optional values = searchFrom(lateral.getSubquery()) + .skipOnlyWhen(ProjectNode.class::isInstance) + .where(ValuesNode.class::isInstance) + .findSingle(); + + if (!values.isPresent() || !isSingleRowValuesWithNoColumns(values.get())) { + return rewrittenLateral; + } + + List subqueryProjections = searchFrom(lateral.getSubquery()) + .where(ProjectNode.class::isInstance) + .findAll(); + + if (subqueryProjections.size() == 0) { + return rewrittenLateral.getInput(); + } + else if (subqueryProjections.size() == 1) { + Assignments assignments = Assignments.builder() + .putIdentities(rewrittenLateral.getInput().getOutputSymbols()) + .putAll(subqueryProjections.get(0).getAssignments()) + .build(); + return projectNode(rewrittenLateral.getInput(), assignments); + } + return rewrittenLateral; + } + + private ProjectNode projectNode(PlanNode source, Assignments assignments) + { + return new ProjectNode(idAllocator.getNextId(), source, assignments); + } + + private static boolean isSingleRowValuesWithNoColumns(ValuesNode values) + { + return values.getRows().size() == 1 && values.getRows().get(0).size() == 0; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToScalarApply.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java similarity index 87% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToScalarApply.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index 5e4d9c8f0cc24..dd9f15d1b4b77 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToScalarApply.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -25,8 +25,10 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; @@ -68,12 +70,12 @@ import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; -public class TransformQuantifiedComparisonApplyToScalarApply +public class TransformQuantifiedComparisonApplyToLateralJoin implements PlanOptimizer { private final Metadata metadata; - public TransformQuantifiedComparisonApplyToScalarApply(Metadata metadata) + public TransformQuantifiedComparisonApplyToLateralJoin(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @@ -143,35 +145,40 @@ private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparison idAllocator.getNextId(), subqueryPlan, ImmutableMap.of( - minValue, new FunctionCall(MIN, outputColumnReferences), - maxValue, new FunctionCall(MAX, outputColumnReferences), - countAllValue, new FunctionCall(COUNT, emptyList()), - countNonNullValue, new FunctionCall(COUNT, outputColumnReferences) + minValue, new Aggregation( + new FunctionCall(MIN, outputColumnReferences), + functionRegistry.resolveFunction(MIN, fromTypeSignatures(outputColumnTypeSignature)), + Optional.empty()), + maxValue, new Aggregation( + new FunctionCall(MAX, outputColumnReferences), + functionRegistry.resolveFunction(MAX, fromTypeSignatures(outputColumnTypeSignature)), + Optional.empty()), + countAllValue, new Aggregation( + new FunctionCall(COUNT, emptyList()), + functionRegistry.resolveFunction(COUNT, emptyList()), + Optional.empty()), + countNonNullValue, new Aggregation( + new FunctionCall(COUNT, outputColumnReferences), + functionRegistry.resolveFunction(COUNT, fromTypeSignatures(outputColumnTypeSignature)), + Optional.empty()) ), - ImmutableMap.of( - minValue, functionRegistry.resolveFunction(MIN, fromTypeSignatures(outputColumnTypeSignature)), - maxValue, functionRegistry.resolveFunction(MAX, fromTypeSignatures(outputColumnTypeSignature)), - countAllValue, functionRegistry.resolveFunction(COUNT, emptyList()), - countNonNullValue, functionRegistry.resolveFunction(COUNT, fromTypeSignatures(outputColumnTypeSignature)) - ), - ImmutableMap.of(), ImmutableList.of(ImmutableList.of()), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()); - PlanNode applyNode = new ApplyNode( + PlanNode lateralJoinNode = new LateralJoinNode( node.getId(), context.rewrite(node.getInput()), subqueryPlan, - Assignments.identity(minValue, maxValue), - node.getCorrelation()); + node.getCorrelation(), + LateralJoinNode.Type.INNER); Expression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue); Symbol quantifiedComparisonSymbol = getOnlyElement(node.getSubqueryAssignments().getSymbols()); - return projectExpressions(applyNode, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery)); + return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery)); } public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java index 9db6ecddb875d..f2efb629e2a12 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -55,6 +55,7 @@ * - semiJoinOutput: semijoinresult *
*/ +@Deprecated public class TransformUncorrelatedInPredicateSubqueryToSemiJoin implements PlanOptimizer { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedScalarToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java similarity index 87% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedScalarToJoin.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java index c6ec646c7e33c..f9fe5e1010ccb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedScalarToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformUncorrelatedLateralToJoin.java @@ -18,8 +18,8 @@ import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; -import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.google.common.collect.ImmutableList; @@ -29,7 +29,8 @@ import static java.util.Objects.requireNonNull; -public class TransformUncorrelatedScalarToJoin +@Deprecated +public class TransformUncorrelatedLateralToJoin implements PlanOptimizer { @Override @@ -49,10 +50,10 @@ public Rewriter(PlanNodeIdAllocator idAllocator) } @Override - public PlanNode visitApply(ApplyNode node, RewriteContext context) + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) { - ApplyNode rewrittenNode = (ApplyNode) context.defaultRewrite(node, context.get()); - if (rewrittenNode.getCorrelation().isEmpty() && rewrittenNode.isResolvedScalarSubquery()) { + LateralJoinNode rewrittenNode = (LateralJoinNode) context.defaultRewrite(node, context.get()); + if (rewrittenNode.getCorrelation().isEmpty()) { return new JoinNode( idAllocator.getNextId(), JoinNode.Type.INNER, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 16c9b29e6a29a..78b6711c3292a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -38,6 +38,7 @@ import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -276,7 +277,7 @@ public PlanNode visitExchange(ExchangeNode node, RewriteContext context) node.getPartitioningScheme().getPartitioning().translate(this::canonicalize), outputs.build(), canonicalize(node.getPartitioningScheme().getHashColumn()), - node.getPartitioningScheme().isReplicateNulls(), + node.getPartitioningScheme().isReplicateNullsAndAny(), node.getPartitioningScheme().getBucketToPartition()); return new ExchangeNode(node.getId(), node.getType(), node.getScope(), partitioningScheme, sources, inputs); @@ -343,7 +344,7 @@ public PlanNode visitLimit(LimitNode node, RewriteContext context) @Override public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext context) { - return new DistinctLimitNode(node.getId(), context.rewrite(node.getSource()), node.getLimit(), node.isPartial(), canonicalize(node.getHashSymbol())); + return new DistinctLimitNode(node.getId(), context.rewrite(node.getSource()), node.getLimit(), node.isPartial(), canonicalizeAndDistinct(node.getDistinctSymbols()), canonicalize(node.getHashSymbol())); } @Override @@ -439,20 +440,23 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) return new ApplyNode(node.getId(), source, subquery, canonicalize(node.getSubqueryAssignments()), canonicalCorrelation); } + @Override + public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getInput()); + PlanNode subquery = context.rewrite(node.getSubquery()); + List canonicalCorrelation = canonicalizeAndDistinct(node.getCorrelation()); + + return new LateralJoinNode(node.getId(), source, subquery, canonicalCorrelation, node.getType()); + } + @Override public PlanNode visitTopN(TopNNode node, RewriteContext context) { PlanNode source = context.rewrite(node.getSource()); - ImmutableList.Builder symbols = ImmutableList.builder(); - ImmutableMap.Builder orderings = ImmutableMap.builder(); - for (Symbol symbol : node.getOrderBy()) { - Symbol canonical = canonicalize(symbol); - symbols.add(canonical); - orderings.put(canonical, node.getOrderings().get(symbol)); - } - - return new TopNNode(node.getId(), source, node.getCount(), symbols.build(), orderings.build(), node.isPartial()); + SymbolMapper mapper = new SymbolMapper(mapping); + return mapper.map(node, source, node.getId()); } @Override @@ -733,7 +737,7 @@ private PartitioningScheme canonicalizePartitionFunctionBinding(PartitioningSche scheme.getPartitioning().translate(this::canonicalize), outputs.build(), canonicalize(scheme.getHashColumn()), - scheme.isReplicateNulls(), + scheme.isReplicateNullsAndAny(), scheme.getBucketToPartition()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java new file mode 100644 index 0000000000000..e032b5f525a79 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowNodeUtil.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.plan.WindowNode; + +public final class WindowNodeUtil +{ + private WindowNodeUtil() {} + + public static boolean dependsOn(WindowNode parent, WindowNode child) + { + return parent.getPartitionBy().stream().anyMatch(child.getCreatedSymbols()::contains) + || parent.getOrderBy().stream().anyMatch(child.getCreatedSymbols()::contains) + || parent.getWindowFunctions().values().stream() + .map(WindowNode.Function::getFunctionCall) + .map(SymbolsExtractor::extractUnique) + .flatMap(symbols -> symbols.stream()) + .anyMatch(child.getCreatedSymbols()::contains); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java index afc797fbdcf20..c1f94166e40d2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java @@ -14,6 +14,8 @@ package com.facebook.presto.sql.planner.optimizations.joins; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -34,6 +36,7 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -50,10 +53,27 @@ public class JoinGraph private final Multimap edges; private final PlanNodeId rootId; + /** + * Builds all (distinct) {@link JoinGraph}-es whole plan tree. + */ public static List buildFrom(PlanNode plan) + { + return buildFrom(plan, Lookup.noLookup()); + } + + /** + * Builds {@link JoinGraph} containing {@code plan} node. + */ + public static JoinGraph buildShallowFrom(PlanNode plan, Lookup lookup) + { + JoinGraph graph = plan.accept(new Builder(true, lookup), new Context()); + return graph; + } + + private static List buildFrom(PlanNode plan, Lookup lookup) { Context context = new Context(); - JoinGraph graph = plan.accept(new Builder(), context); + JoinGraph graph = plan.accept(new Builder(false, lookup), context); if (graph.size() > 1) { context.addSubGraph(graph); } @@ -197,17 +217,29 @@ private JoinGraph joinWith(JoinGraph other, List joinCl } private static class Builder - extends PlanVisitor + extends PlanVisitor { + // TODO When com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'shallow' flag + private final boolean shallow; + private final Lookup lookup; + + private Builder(boolean shallow, Lookup lookup) + { + this.shallow = shallow; + this.lookup = requireNonNull(lookup, "lookup cannot be null"); + } + @Override protected JoinGraph visitPlan(PlanNode node, Context context) { - for (PlanNode child : node.getSources()) { - JoinGraph graph = child.accept(this, context); - if (graph.size() < 2) { - continue; + if (!shallow) { + for (PlanNode child : node.getSources()) { + JoinGraph graph = child.accept(this, context); + if (graph.size() < 2) { + continue; + } + context.addSubGraph(graph.withRootId(child.getId())); } - context.addSubGraph(graph.withRootId(child.getId())); } for (Symbol symbol : node.getOutputSymbols()) { @@ -251,6 +283,34 @@ public JoinGraph visitProject(ProjectNode node, Context context) } return visitPlan(node, context); } + + @Override + public JoinGraph visitGroupReference(GroupReference node, Context context) + { + PlanNode dereferenced = lookup.resolve(node); + JoinGraph graph = dereferenced.accept(this, context); + if (isTrivialGraph(graph)) { + return replacementGraph(dereferenced, node, context); + } + return graph; + } + + private boolean isTrivialGraph(JoinGraph graph) + { + return graph.nodes.size() < 2 && graph.edges.isEmpty() && graph.filters.isEmpty() && !graph.assignments.isPresent(); + } + + private JoinGraph replacementGraph(PlanNode oldNode, PlanNode newNode, Context context) + { + // TODO optimize when idea is generally approved + List symbols = context.symbolSources.entrySet().stream() + .filter(entry -> entry.getValue() == oldNode) + .map(Map.Entry::getKey) + .collect(toImmutableList()); + symbols.forEach(symbol -> context.symbolSources.put(symbol, newNode)); + + return new JoinGraph(newNode); + } } public static class Edge @@ -285,6 +345,8 @@ public Symbol getTargetSymbol() private static class Context { private final Map symbolSources = new HashMap<>(); + + // TODO When com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'joinGraphs' private final List joinGraphs = new ArrayList<>(); public void setSymbolSource(Symbol symbol, PlanNode node) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java index b20072d751d90..cbd9ed91e54e6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java @@ -34,6 +34,7 @@ import java.util.stream.Collectors; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.util.MoreLists.listOfListsCopy; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -42,7 +43,7 @@ public class AggregationNode extends PlanNode { private final PlanNode source; - private final Map assignments; + private final Map aggregations; private final List> groupingSets; private final Step step; private final Optional hashSymbol; @@ -110,7 +111,7 @@ public static Step partialInput(Step step) public AggregationNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("assignments") Map assignments, + @JsonProperty("aggregations") Map aggregations, @JsonProperty("groupingSets") List> groupingSets, @JsonProperty("step") Step step, @JsonProperty("hashSymbol") Optional hashSymbol, @@ -119,10 +120,10 @@ public AggregationNode( super(id); this.source = source; - this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "aggregations is null")); + this.aggregations = ImmutableMap.copyOf(requireNonNull(aggregations, "aggregations is null")); requireNonNull(groupingSets, "groupingSets is null"); checkArgument(!groupingSets.isEmpty(), "grouping sets list cannot be empty"); - this.groupingSets = ImmutableList.copyOf(groupingSets); + this.groupingSets = listOfListsCopy(groupingSets); this.step = step; this.hashSymbol = hashSymbol; this.groupIdSymbol = requireNonNull(groupIdSymbol); @@ -130,29 +131,11 @@ public AggregationNode( ImmutableList.Builder outputs = ImmutableList.builder(); outputs.addAll(getGroupingKeys()); hashSymbol.ifPresent(outputs::add); - outputs.addAll(assignments.keySet()); + outputs.addAll(aggregations.keySet()); this.outputs = outputs.build(); } - /** - * @deprecated pass Assignments object instead - */ - @Deprecated - public AggregationNode( - PlanNodeId id, - PlanNode source, - Map assignments, - Map functions, - Map masks, - List> groupingSets, - Step step, - Optional hashSymbol, - Optional groupIdSymbol) - { - this(id, source, makeAssignments(assignments, functions, masks), groupingSets, step, hashSymbol, groupIdSymbol); - } - @Override public List getSources() { @@ -166,56 +149,9 @@ public List getOutputSymbols() } @JsonProperty - public Map getAssignments() + public Map getAggregations() { - return assignments; - } - - /** - * @deprecated Use getAssignments - */ - @Deprecated - public Map getAggregations() - { - // use an ImmutableMap.Builder because the output has to preserve - // the iteration order of the original map. - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry entry : assignments.entrySet()) { - builder.put(entry.getKey(), entry.getValue().getCall()); - } - return builder.build(); - } - - /** - * @deprecated Use getAssignments - */ - @Deprecated - public Map getFunctions() - { - // use an ImmutableMap.Builder because the output has to preserve - // the iteration order of the original map. - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry entry : assignments.entrySet()) { - builder.put(entry.getKey(), entry.getValue().getSignature()); - } - return builder.build(); - } - - /** - * @deprecated Use getAssignments - */ - @Deprecated - public Map getMasks() - { - // use an ImmutableMap.Builder because the output has to preserve - // the iteration order of the original map. - ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry entry : assignments.entrySet()) { - entry.getValue() - .getMask() - .ifPresent(symbol -> builder.put(entry.getKey(), symbol)); - } - return builder.build(); + return aggregations; } public List getGroupingKeys() @@ -273,7 +209,7 @@ public Optional getGroupIdSymbol() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitAggregation(this, context); } @@ -281,34 +217,16 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new AggregationNode(getId(), Iterables.getOnlyElement(newChildren), assignments, groupingSets, step, hashSymbol, groupIdSymbol); + return new AggregationNode(getId(), Iterables.getOnlyElement(newChildren), aggregations, groupingSets, step, hashSymbol, groupIdSymbol); } public boolean isDecomposable(FunctionRegistry functionRegistry) { - return getFunctions().values().stream() - .map(functionRegistry::getAggregateFunctionImplementation) + return getAggregations().entrySet().stream() + .map(entry -> functionRegistry.getAggregateFunctionImplementation(entry.getValue().getSignature())) .allMatch(InternalAggregationFunction::isDecomposable); } - private static Map makeAssignments( - Map aggregations, - Map functions, - Map masks) - { - ImmutableMap.Builder builder = ImmutableMap.builder(); - - for (Map.Entry entry : aggregations.entrySet()) { - Symbol output = entry.getKey(); - builder.put(output, new Aggregation( - entry.getValue(), - functions.get(output), - Optional.ofNullable(masks.get(output)))); - } - - return builder.build(); - } - public static class Aggregation { private final FunctionCall call; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java index 1fdbca94eb0f6..61ae607603e0c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java @@ -14,7 +14,10 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.sql.tree.ExistsPredicate; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -23,7 +26,6 @@ import java.util.List; -import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -57,9 +59,6 @@ public class ApplyNode * - expression: input_symbol_X < ALL (subquery_symbol_Y) * - meaning: if input_symbol_X is smaller than all subquery values represented by subquery_symbol_Y *

- * Example 3: - * - expression: subquery_symbol_Y - * - meaning: subquery is scalar (might be enforced), therefore subquery_symbol_Y can be used directly in the rest of the plan */ private final Assignments subqueryAssignments; @@ -78,6 +77,9 @@ public ApplyNode( requireNonNull(correlation, "correlation is null"); checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); + checkArgument( + subqueryAssignments.getExpressions().stream().allMatch(ApplyNode::isSupportedSubqueryExpression), + "Unexpected expression used for subquery expression"); this.input = input; this.subquery = subquery; @@ -85,13 +87,11 @@ public ApplyNode( this.correlation = ImmutableList.copyOf(correlation); } - /** - * @return true when subquery is scalar and it's output symbols are directly mapped to ApplyNode output symbols - */ - public boolean isResolvedScalarSubquery() + private static boolean isSupportedSubqueryExpression(Expression expression) { - return isScalar(subquery) && subqueryAssignments.getExpressions().stream() - .allMatch(expression -> expression instanceof SymbolReference); + return expression instanceof InPredicate || + expression instanceof ExistsPredicate || + expression instanceof QuantifiedComparisonExpression; } @JsonProperty("input") @@ -135,7 +135,7 @@ public List getOutputSymbols() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitApply(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java index a81b462d43790..38bc7e12e855f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignUniqueId.java @@ -69,7 +69,7 @@ public Symbol getIdColumn() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitAssignUniqueId(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java index 347a1c5b2940b..6ef6a181e974c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DeleteNode.java @@ -83,7 +83,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitDelete(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java index a5014716fba67..d29c5c4b2b3f7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/DistinctLimitNode.java @@ -25,7 +25,6 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Predicates.not; import static java.util.Objects.requireNonNull; @Immutable @@ -35,6 +34,7 @@ public class DistinctLimitNode private final PlanNode source; private final long limit; private final boolean partial; + private final List distinctSymbols; private final Optional hashSymbol; @JsonCreator @@ -43,6 +43,7 @@ public DistinctLimitNode( @JsonProperty("source") PlanNode source, @JsonProperty("limit") long limit, @JsonProperty("partial") boolean partial, + @JsonProperty("distinctSymbols") List distinctSymbols, @JsonProperty("hashSymbol") Optional hashSymbol) { super(id); @@ -50,7 +51,9 @@ public DistinctLimitNode( checkArgument(limit >= 0, "limit must be greater than or equal to zero"); this.limit = limit; this.partial = partial; + this.distinctSymbols = ImmutableList.copyOf(distinctSymbols); this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + checkArgument(!hashSymbol.isPresent() || !distinctSymbols.contains(hashSymbol.get()), "distinctSymbols should not contain hash symbol"); } @Override @@ -83,22 +86,23 @@ public Optional getHashSymbol() return hashSymbol; } + @JsonProperty public List getDistinctSymbols() { - if (hashSymbol.isPresent()) { - return ImmutableList.copyOf(Iterables.filter(getOutputSymbols(), not(hashSymbol.get()::equals))); - } - return getOutputSymbols(); + return distinctSymbols; } @Override public List getOutputSymbols() { - return source.getOutputSymbols(); + ImmutableList.Builder outputSymbols = ImmutableList.builder(); + outputSymbols.addAll(distinctSymbols); + hashSymbol.ifPresent(outputSymbols::add); + return outputSymbols.build(); } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitDistinctLimit(this, context); } @@ -106,6 +110,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new DistinctLimitNode(getId(), Iterables.getOnlyElement(newChildren), limit, partial, hashSymbol); + return new DistinctLimitNode(getId(), Iterables.getOnlyElement(newChildren), limit, partial, distinctSymbols, hashSymbol); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java index 3f137334e416a..d17e5c82decd5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java @@ -60,7 +60,7 @@ public PlanNode getSource() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitEnforceSingleRow(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java index b849fe1138b8c..962a422908e88 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExceptNode.java @@ -35,7 +35,7 @@ public ExceptNode( } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitExcept(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java index 184d58c0b8f02..48b49c4b66be0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java @@ -31,6 +31,7 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static com.facebook.presto.util.MoreLists.listOfListsCopy; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -88,13 +89,13 @@ public ExchangeNode( checkArgument(scope != LOCAL || partitioningScheme.getPartitioning().getArguments().stream().allMatch(ArgumentBinding::isVariable), "local exchanges do not support constant partition function arguments"); - checkArgument(scope != REMOTE || type == Type.REPARTITION || !partitioningScheme.isReplicateNulls(), "Only REPARTITION can remotely replicate nulls"); + checkArgument(scope != REMOTE || type == Type.REPARTITION || !partitioningScheme.isReplicateNullsAndAny(), "Only REPARTITION can replicate remotely"); this.type = type; this.sources = sources; this.scope = scope; this.partitioningScheme = partitioningScheme; - this.inputs = ImmutableList.copyOf(inputs); + this.inputs = listOfListsCopy(inputs); } public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumns) @@ -102,7 +103,7 @@ public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanN return partitionedExchange(id, scope, child, partitioningColumns, hashColumns, false); } - public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumns, boolean nullsReplicated) + public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanNode child, List partitioningColumns, Optional hashColumns, boolean replicateNullsAndAny) { return partitionedExchange( id, @@ -112,7 +113,7 @@ public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanN Partitioning.create(FIXED_HASH_DISTRIBUTION, partitioningColumns), child.getOutputSymbols(), hashColumns, - nullsReplicated, + replicateNullsAndAny, Optional.empty())); } @@ -190,7 +191,7 @@ public List> getInputs() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitExchange(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java index 5911112bef032..5d5da31e29544 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java @@ -68,7 +68,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitExplainAnalyze(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java index 969f7f6feef99..b1773f1a5df3c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/FilterNode.java @@ -67,7 +67,7 @@ public PlanNode getSource() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitFilter(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java index e7af74e63d345..6b2e2feb447a8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.facebook.presto.util.MoreLists.listOfListsCopy; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toSet; @@ -60,7 +61,7 @@ public GroupIdNode(@JsonProperty("id") PlanNodeId id, { super(id); this.source = requireNonNull(source); - this.groupingSets = ImmutableList.copyOf(requireNonNull(groupingSets)); + this.groupingSets = listOfListsCopy(requireNonNull(groupingSets, "groupingSets is null")); this.groupingSetMappings = ImmutableMap.copyOf(requireNonNull(groupingSetMappings)); this.argumentMappings = ImmutableMap.copyOf(requireNonNull(argumentMappings)); this.groupIdSymbol = requireNonNull(groupIdSymbol); @@ -117,7 +118,7 @@ public Symbol getGroupIdSymbol() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitGroupId(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java index 956216d546555..b20d8e2c2e0da 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexJoinNode.java @@ -126,7 +126,7 @@ public List getOutputSymbols() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitIndexJoin(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java index 4fd03bd825cd5..e1c7eedff9b42 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IndexSourceNode.java @@ -119,7 +119,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitIndexSource(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java index 408cd41a4cd5a..a9ee23f533939 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/IntersectNode.java @@ -37,7 +37,7 @@ public IntersectNode( } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitIntersect(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java index 51491db326f5b..3ecc8862ae35b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java @@ -14,11 +14,14 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Join; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import javax.annotation.concurrent.Immutable; @@ -27,6 +30,7 @@ import java.util.Optional; import java.util.stream.Stream; +import static com.facebook.presto.sql.planner.SortExpressionExtractor.extractSortExpression; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -91,6 +95,9 @@ public JoinNode(@JsonProperty("id") PlanNodeId id, .build(); checkArgument(inputSymbols.containsAll(outputSymbols), "Left and right join inputs do not contain all output symbols"); checkArgument(!isCrossJoin() || inputSymbols.equals(outputSymbols), "Cross join does not support output symbols pruning or reordering"); + + checkArgument(!(criteria.isEmpty() && leftHashSymbol.isPresent()), "Left hash symbol is only valid in an equijoin"); + checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash symbol is only valid in an equijoin"); } public enum DistributionType @@ -168,6 +175,11 @@ public Optional getFilter() return filter; } + public Optional getSortExpression() + { + return filter.map(filter -> extractSortExpression(ImmutableSet.copyOf(right.getOutputSymbols()), filter).orElse(null)); + } + @JsonProperty("leftHashSymbol") public Optional getLeftHashSymbol() { @@ -200,7 +212,7 @@ public Optional getDistributionType() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitJoin(this, context); } @@ -247,6 +259,11 @@ public Symbol getRight() return right; } + public ComparisonExpression toExpression() + { + return new ComparisonExpression(ComparisonExpressionType.EQUAL, left.toSymbolReference(), right.toSymbolReference()); + } + @Override public boolean equals(Object obj) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java new file mode 100644 index 0000000000000..e294b577b8ab1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.sql.planner.Symbol; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * For every row from {@link #input} a {@link #subquery} relation is calculated. + * Then input row is cross joined with subquery relation and returned as a result. + * + * INNER - does not return any row for input row when subquery relation is empty + * LEFT - does return input completed with NULL values when subquery relation is empty + */ +@Immutable +public class LateralJoinNode + extends PlanNode +{ + public enum Type + { + INNER, + LEFT + } + + private final PlanNode input; + private final PlanNode subquery; + + /** + * Correlation symbols, returned from input (outer plan) used in subquery (inner plan) + */ + private final List correlation; + private final Type type; + + @JsonCreator + public LateralJoinNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("input") PlanNode input, + @JsonProperty("subquery") PlanNode subquery, + @JsonProperty("correlation") List correlation, + @JsonProperty("type") Type type) + { + super(id); + requireNonNull(input, "input is null"); + requireNonNull(subquery, "right is null"); + requireNonNull(correlation, "correlation is null"); + + checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); + + this.input = input; + this.subquery = subquery; + this.correlation = ImmutableList.copyOf(correlation); + this.type = type; + } + + @JsonProperty("input") + public PlanNode getInput() + { + return input; + } + + @JsonProperty("subquery") + public PlanNode getSubquery() + { + return subquery; + } + + @JsonProperty("correlation") + public List getCorrelation() + { + return correlation; + } + + @JsonProperty("type") + public Type getType() + { + return type; + } + + @Override + public List getSources() + { + return ImmutableList.of(input, subquery); + } + + @Override + @JsonProperty("outputSymbols") + public List getOutputSymbols() + { + return ImmutableList.builder() + .addAll(input.getOutputSymbols()) + .addAll(subquery.getOutputSymbols()) + .build(); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); + return new LateralJoinNode(getId(), newChildren.get(0), newChildren.get(1), correlation, type); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitLateralJoin(this, context); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java index 22e94460c44ec..feb87a876f655 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/LimitNode.java @@ -82,7 +82,7 @@ public List getOutputSymbols() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitLimit(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java index 66a8d496ed4a8..1dbfc978d7057 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MarkDistinctNode.java @@ -90,7 +90,7 @@ public Optional getHashSymbol() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitMarkDistinct(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java index ecdb7120d5bbf..9c86052b2fe54 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/MetadataDeleteNode.java @@ -79,7 +79,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitMetadataDelete(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java index 8eb08dd6eb554..fc368cc4dc367 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/OutputNode.java @@ -77,7 +77,7 @@ public PlanNode getSource() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitOutput(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java index 22a120157f2eb..92f1764ed2f7e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanNode.java @@ -60,6 +60,7 @@ @JsonSubTypes.Type(value = ExplainAnalyzeNode.class, name = "explainAnalyze"), @JsonSubTypes.Type(value = ApplyNode.class, name = "apply"), @JsonSubTypes.Type(value = AssignUniqueId.class, name = "assignUniqueId"), + @JsonSubTypes.Type(value = LateralJoinNode.class, name = "lateralJoin"), }) public abstract class PlanNode { @@ -83,7 +84,7 @@ public PlanNodeId getId() public abstract PlanNode replaceChildren(List newChildren); - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitPlan(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanRewriter.java index 353e597993424..46685420275ab 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanRewriter.java @@ -23,7 +23,7 @@ import static java.util.Objects.requireNonNull; public abstract class PlanRewriter - extends PlanVisitor, PlanRewriter.Result

> + extends PlanVisitor, PlanRewriter.RewriteContext> { public static Result

rewriteWith(PlanRewriter rewriter, PlanNode node) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanVisitor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanVisitor.java index 489263071a5cc..5e1a21b611339 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanVisitor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/PlanVisitor.java @@ -13,7 +13,9 @@ */ package com.facebook.presto.sql.planner.plan; -public abstract class PlanVisitor +import com.facebook.presto.sql.planner.iterative.GroupReference; + +public abstract class PlanVisitor { protected abstract R visitPlan(PlanNode node, C context); @@ -186,4 +188,14 @@ public R visitAssignUniqueId(AssignUniqueId node, C context) { return visitPlan(node, context); } + + public R visitGroupReference(GroupReference node, C context) + { + return visitPlan(node, context); + } + + public R visitLateralJoin(LateralJoinNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java index 2da31a917d5b9..9d3d763089159 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ProjectNode.java @@ -87,7 +87,7 @@ public boolean isIdentity() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitProject(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java index 97751f0070bf3..b4048a6eef875 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java @@ -71,7 +71,7 @@ public List getSourceFragmentIds() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitRemoteSource(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java index 754c52d65fabf..f720d174e314b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java @@ -104,7 +104,7 @@ public Optional getHashSymbol() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitRowNumber(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java index 2e7b126186e74..b113212522c24 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java @@ -101,7 +101,7 @@ public List getOutputSymbols() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitSample(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java index 70e68d71a83b6..36cdc9a1a2d93 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SemiJoinNode.java @@ -134,7 +134,7 @@ public List getOutputSymbols() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitSemiJoin(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 1f01b4c8abfb3..8ee1972bd3569 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -20,7 +20,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; public abstract class SimplePlanRewriter - extends PlanVisitor, PlanNode> + extends PlanVisitor> { public static PlanNode rewriteWith(SimplePlanRewriter rewriter, PlanNode node) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java index 62b63e54e1d85..db7ba03373593 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/SortNode.java @@ -83,7 +83,7 @@ public Map getOrderings() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitSort(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java index 3eee68615d9eb..86d4bb2f29f90 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableFinishNode.java @@ -76,7 +76,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitTableFinish(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java index 387711fb0c209..95d0f088d9b69 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableScanNode.java @@ -130,7 +130,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitTableScan(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java index d3cb313d3157e..83c1c2f9e6d7c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterNode.java @@ -115,7 +115,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitTableWriter(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java index b6b87ca9a3165..a2fcfddd88107 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNNode.java @@ -17,7 +17,6 @@ import com.facebook.presto.sql.planner.Symbol; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -29,17 +28,24 @@ import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.util.Failures.checkCondition; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @Immutable public class TopNNode extends PlanNode { + public enum Step { + SINGLE, + PARTIAL, + FINAL + } + private final PlanNode source; private final long count; private final List orderBy; private final Map orderings; - private final boolean partial; + private final Step step; @JsonCreator public TopNNode(@JsonProperty("id") PlanNodeId id, @@ -47,22 +53,22 @@ public TopNNode(@JsonProperty("id") PlanNodeId id, @JsonProperty("count") long count, @JsonProperty("orderBy") List orderBy, @JsonProperty("orderings") Map orderings, - @JsonProperty("partial") boolean partial) + @JsonProperty("step") Step step) { super(id); requireNonNull(source, "source is null"); - Preconditions.checkArgument(count >= 0, "count must be positive"); + checkArgument(count >= 0, "count must be positive"); checkCondition(count <= Integer.MAX_VALUE, NOT_SUPPORTED, "ORDER BY LIMIT > %s is not supported", Integer.MAX_VALUE); requireNonNull(orderBy, "orderBy is null"); - Preconditions.checkArgument(!orderBy.isEmpty(), "orderBy is empty"); - Preconditions.checkArgument(orderings.size() == orderBy.size(), "orderBy and orderings sizes don't match"); + checkArgument(!orderBy.isEmpty(), "orderBy is empty"); + checkArgument(orderings.size() == orderBy.size(), "orderBy and orderings sizes don't match"); this.source = source; this.count = count; this.orderBy = ImmutableList.copyOf(orderBy); this.orderings = ImmutableMap.copyOf(orderings); - this.partial = partial; + this.step = requireNonNull(step, "step is null"); } @Override @@ -101,14 +107,14 @@ public Map getOrderings() return orderings; } - @JsonProperty("partial") - public boolean isPartial() + @JsonProperty("step") + public Step getStep() { - return partial; + return step; } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitTopN(this, context); } @@ -116,6 +122,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new TopNNode(getId(), Iterables.getOnlyElement(newChildren), count, orderBy, orderings, partial); + return new TopNNode(getId(), Iterables.getOnlyElement(newChildren), count, orderBy, orderings, step); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java index 9f3bfeb97d68d..bc672cabad06d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java @@ -135,7 +135,7 @@ public Optional getHashSymbol() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitTopNRowNumber(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java index d04a9cc261642..ba4e9c45e29fb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnionNode.java @@ -37,7 +37,7 @@ public UnionNode( } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitUnion(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java index afe7686f67d0e..ed1f6e979c122 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/UnnestNode.java @@ -101,7 +101,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitUnnest(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java index 6c3cb64852843..619db0952d7c3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ValuesNode.java @@ -23,6 +23,7 @@ import java.util.List; +import static com.facebook.presto.util.MoreLists.listOfListsCopy; import static com.google.common.base.Preconditions.checkArgument; @Immutable @@ -39,7 +40,7 @@ public ValuesNode(@JsonProperty("id") PlanNodeId id, { super(id); this.outputSymbols = ImmutableList.copyOf(outputSymbols); - this.rows = ImmutableList.copyOf(rows); + this.rows = listOfListsCopy(rows); for (List row : rows) { checkArgument(row.size() == outputSymbols.size() || row.size() == 0, @@ -67,7 +68,7 @@ public List getSources() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitValues(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java index e22dadd33a913..137a4e66b6578 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java @@ -154,7 +154,7 @@ public int getPreSortedOrderPrefix() } @Override - public R accept(PlanVisitor visitor, C context) + public R accept(PlanVisitor visitor, C context) { return visitor.visitWindow(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java index 8f1262d3224ed..3356b2d5a39de 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java @@ -29,7 +29,7 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.stream.Collectors.toMap; -class PlanNodeStats +public class PlanNodeStats { private final PlanNodeId planNodeId; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java index 13c3a5eb0b06e..382e7e4a30b61 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java @@ -43,7 +43,7 @@ public class PlanNodeStatsSummarizer { private PlanNodeStatsSummarizer() {} - static Map aggregatePlanNodeStats(StageInfo stageInfo) + public static Map aggregatePlanNodeStats(StageInfo stageInfo) { Map aggregatedStats = new HashMap<>(); List planNodeStats = stageInfo.getTasks().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 46ac7f0d4f866..d3971ca45189a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -14,6 +14,8 @@ package com.facebook.presto.sql.planner.planPrinter; import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.execution.StageInfo; import com.facebook.presto.execution.StageStats; import com.facebook.presto.metadata.Metadata; @@ -28,6 +30,7 @@ import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.FunctionInvoker; import com.facebook.presto.sql.planner.Partitioning; @@ -35,7 +38,9 @@ import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; @@ -52,6 +57,7 @@ import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; @@ -107,6 +113,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static com.facebook.presto.cost.PlanNodeCost.UNKNOWN_COST; import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.DomainUtils.simplifyDomain; @@ -117,6 +124,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Double.isFinite; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -128,34 +136,38 @@ public class PlanPrinter private final Metadata metadata; private final Optional> stats; - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, Session sesion) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session sesion) { - this(plan, types, metadata, sesion, 0); + this(plan, types, metadata, costCalculator, sesion, 0); } - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, Session session, int indent) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent) { requireNonNull(plan, "plan is null"); requireNonNull(types, "types is null"); requireNonNull(metadata, "metadata is null"); + requireNonNull(costCalculator, "costCalculator is null"); this.metadata = metadata; this.stats = Optional.empty(); - Visitor visitor = new Visitor(types, session); + Map costs = costCalculator.calculateCostForPlan(session, types, plan); + Visitor visitor = new Visitor(types, costs, session); plan.accept(visitor, indent); } - private PlanPrinter(PlanNode plan, Map types, Metadata metadata, Session session, Map stats, int indent) + private PlanPrinter(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent) { requireNonNull(plan, "plan is null"); requireNonNull(types, "types is null"); requireNonNull(metadata, "metadata is null"); + requireNonNull(costCalculator, "costCalculator is null"); this.metadata = metadata; this.stats = Optional.of(stats); - Visitor visitor = new Visitor(types, session); + Map costs = costCalculator.calculateCostForPlan(session, types, plan); + Visitor visitor = new Visitor(types, costs, session); plan.accept(visitor, indent); } @@ -165,22 +177,22 @@ public String toString() return output.toString(); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, Session session) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session) { - return new PlanPrinter(plan, types, metadata, session).toString(); + return new PlanPrinter(plan, types, metadata, costCalculator, session).toString(); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, Session session, int indent) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, int indent) { - return new PlanPrinter(plan, types, metadata, session, indent).toString(); + return new PlanPrinter(plan, types, metadata, costCalculator, session, indent).toString(); } - public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, Session session, Map stats, int indent) + public static String textLogicalPlan(PlanNode plan, Map types, Metadata metadata, CostCalculator costCalculator, Session session, Map stats, int indent) { - return new PlanPrinter(plan, types, metadata, session, stats, indent).toString(); + return new PlanPrinter(plan, types, metadata, costCalculator, session, stats, indent).toString(); } - public static String textDistributedPlan(StageInfo outputStageInfo, Metadata metadata, Session session) + public static String textDistributedPlan(StageInfo outputStageInfo, Metadata metadata, CostCalculator costCalculator, Session session) { StringBuilder builder = new StringBuilder(); List allStages = outputStageInfo.getSubStages().stream() @@ -188,23 +200,23 @@ public static String textDistributedPlan(StageInfo outputStageInfo, Metadata met .collect(toImmutableList()); for (StageInfo stageInfo : allStages) { Map aggregatedStats = aggregatePlanNodeStats(stageInfo); - builder.append(formatFragment(metadata, session, stageInfo.getPlan(), Optional.of(stageInfo.getStageStats()), Optional.of(aggregatedStats))); + builder.append(formatFragment(metadata, costCalculator, session, stageInfo.getPlan(), Optional.of(stageInfo.getStageStats()), Optional.of(aggregatedStats))); } return builder.toString(); } - public static String textDistributedPlan(SubPlan plan, Metadata metadata, Session session) + public static String textDistributedPlan(SubPlan plan, Metadata metadata, CostCalculator costCalculator, Session session) { StringBuilder builder = new StringBuilder(); for (PlanFragment fragment : plan.getAllFragments()) { - builder.append(formatFragment(metadata, session, fragment, Optional.empty(), Optional.empty())); + builder.append(formatFragment(metadata, costCalculator, session, fragment, Optional.empty(), Optional.empty())); } return builder.toString(); } - private static String formatFragment(Metadata metadata, Session session, PlanFragment fragment, Optional stageStats, Optional> planNodeStats) + private static String formatFragment(Metadata metadata, CostCalculator costCalculator, Session session, PlanFragment fragment, Optional stageStats, Optional> planNodeStats) { StringBuilder builder = new StringBuilder(); builder.append(format("Fragment %s [%s]\n", @@ -213,7 +225,7 @@ private static String formatFragment(Metadata metadata, Session session, PlanFra if (stageStats.isPresent()) { builder.append(indentString(1)) - .append(format("Cost: CPU %s, Input: %s (%s), Output: %s (%s)\n", + .append(format("CPU: %s, Input: %s (%s), Output: %s (%s)\n", stageStats.get().getTotalCpuTime(), formatPositions(stageStats.get().getProcessedInputPositions()), stageStats.get().getProcessedInputDataSize(), @@ -226,7 +238,7 @@ private static String formatFragment(Metadata metadata, Session session, PlanFra .append(format("Output layout: [%s]\n", Joiner.on(", ").join(partitioningScheme.getOutputLayout()))); - boolean replicateNulls = partitioningScheme.isReplicateNulls(); + boolean replicateNullsAndAny = partitioningScheme.isReplicateNullsAndAny(); List arguments = partitioningScheme.getPartitioning().getArguments().stream() .map(argument -> { if (argument.isConstant()) { @@ -238,8 +250,8 @@ private static String formatFragment(Metadata metadata, Session session, PlanFra }) .collect(toImmutableList()); builder.append(indentString(1)); - if (replicateNulls) { - builder.append(format("Output partitioning: %s (replicate nulls) [%s]%s\n", + if (replicateNullsAndAny) { + builder.append(format("Output partitioning: %s (replicate nulls and any) [%s]%s\n", partitioningScheme.getPartitioning().getHandle(), Joiner.on(", ").join(arguments), formatHash(partitioningScheme.getHashColumn()))); @@ -252,11 +264,11 @@ private static String formatFragment(Metadata metadata, Session session, PlanFra } if (stageStats.isPresent()) { - builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, session, planNodeStats.get(), 1)) + builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, planNodeStats.get(), 1)) .append("\n"); } else { - builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, session, 1)) + builder.append(textLogicalPlan(fragment.getRoot(), fragment.getSymbols(), metadata, costCalculator, session, 1)) .append("\n"); } @@ -331,7 +343,7 @@ private void printStats(int indent, PlanNodeId planNodeId, boolean printInput, b double fraction = 100.0d * nodeStats.getPlanNodeWallTime().toMillis() / totalMillis; output.append(indentString(indent)); - output.append("Cost: " + formatDouble(fraction) + "%"); + output.append("CPU fraction: " + formatDouble(fraction) + "%"); if (printInput) { output.append(format(", Input: %s (%s)", formatPositions(nodeStats.getPlanNodeInputPositions()), @@ -438,15 +450,17 @@ private static String indentString(int indent) } private class Visitor - extends PlanVisitor + extends PlanVisitor { private final Map types; + private final Map costs; private final Session session; @SuppressWarnings("AssignmentToCollectionOrArrayFieldFromParameter") - public Visitor(Map types, Session session) + public Visitor(Map types, Map costs, Session session) { this.types = types; + this.costs = costs; this.session = session; } @@ -454,6 +468,7 @@ public Visitor(Map types, Session session) public Void visitExplainAnalyze(ExplainAnalyzeNode node, Integer indent) { print(indent, "- ExplainAnalyze => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -463,9 +478,7 @@ public Void visitJoin(JoinNode node, Integer indent) { List joinExpressions = new ArrayList<>(); for (JoinNode.EquiJoinClause clause : node.getCriteria()) { - joinExpressions.add(new ComparisonExpression(ComparisonExpressionType.EQUAL, - clause.getLeft().toSymbolReference(), - clause.getRight().toSymbolReference())); + joinExpressions.add(clause.toExpression()); } node.getFilter().ifPresent(expression -> joinExpressions.add(expression)); @@ -481,6 +494,8 @@ public Void visitJoin(JoinNode node, Integer indent) formatOutputs(node.getOutputSymbols())); } + node.getSortExpression().ifPresent(expression -> print(indent + 2, "SortExpression[%s]", expression)); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getLeft().accept(this, indent + 1); node.getRight().accept(this, indent + 1); @@ -496,6 +511,7 @@ public Void visitSemiJoin(SemiJoinNode node, Integer indent) node.getFilteringSourceJoinSymbol(), formatHash(node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getSource().accept(this, indent + 1); node.getFilteringSource().accept(this, indent + 1); @@ -507,6 +523,7 @@ public Void visitSemiJoin(SemiJoinNode node, Integer indent) public Void visitIndexSource(IndexSourceNode node, Integer indent) { print(indent, "- IndexSource[%s, lookup = %s] => [%s]", node.getIndexHandle(), node.getLookupSymbols(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); for (Map.Entry entry : node.getAssignments().entrySet()) { if (node.getOutputSymbols().contains(entry.getKey())) { @@ -531,6 +548,7 @@ public Void visitIndexJoin(IndexJoinNode node, Integer indent) Joiner.on(" AND ").join(joinExpressions), formatHash(node.getProbeHashSymbol(), node.getIndexHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getProbeSource().accept(this, indent + 1); node.getIndexSource().accept(this, indent + 1); @@ -542,6 +560,7 @@ public Void visitIndexJoin(IndexJoinNode node, Integer indent) public Void visitLimit(LimitNode node, Integer indent) { print(indent, "- Limit%s[%s] => [%s]", node.isPartial() ? "Partial" : "", node.getCount(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -554,6 +573,7 @@ public Void visitDistinctLimit(DistinctLimitNode node, Integer indent) node.getLimit(), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -571,14 +591,15 @@ public Void visitAggregation(AggregationNode node, Integer indent) } print(indent, "- Aggregate%s%s%s => [%s]", type, key, formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); - for (Map.Entry entry : node.getAggregations().entrySet()) { - if (node.getMasks().containsKey(entry.getKey())) { - print(indent + 2, "%s := %s (mask = %s)", entry.getKey(), entry.getValue(), node.getMasks().get(entry.getKey())); + for (Map.Entry entry : node.getAggregations().entrySet()) { + if (entry.getValue().getMask().isPresent()) { + print(indent + 2, "%s := %s (mask = %s)", entry.getKey(), entry.getValue().getCall(), entry.getValue().getMask().get()); } else { - print(indent + 2, "%s := %s", entry.getKey(), entry.getValue()); + print(indent + 2, "%s := %s", entry.getKey(), entry.getValue().getCall()); } } @@ -596,6 +617,7 @@ public Void visitGroupId(GroupIdNode node, Integer indent) .collect(Collectors.toList()); print(indent, "- GroupId%s => [%s]", inputGroupingSetSymbols, formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); for (Map.Entry mapping : node.getGroupingSetMappings().entrySet()) { @@ -617,6 +639,7 @@ public Void visitMarkDistinct(MarkDistinctNode node, Integer indent) formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -664,6 +687,7 @@ public Void visitWindow(WindowNode node, Integer indent) } print(indent, "- Window[%s]%s => [%s]", Joiner.on(", ").join(args), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); for (Map.Entry entry : node.getWindowFunctions().entrySet()) { @@ -694,6 +718,7 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Integer indent) node.getMaxRowCountPerPartition(), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); print(indent + 2, "%s := %s", node.getRowNumberSymbol(), "row_number()"); @@ -713,7 +738,11 @@ public Void visitRowNumber(RowNumberNode node, Integer indent) args.add(format("limit = %s", node.getMaxRowCountPerPartition().get())); } - print(indent, "- RowNumber[%s]%s => [%s]", Joiner.on(", ").join(args), formatHash(node.getHashSymbol()), formatOutputs(node.getOutputSymbols())); + print(indent, "- RowNumber[%s]%s => [%s]", + Joiner.on(", ").join(args), + formatHash(node.getHashSymbol()), + formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); print(indent + 2, "%s := %s", node.getRowNumberSymbol(), "row_number()"); @@ -725,6 +754,7 @@ public Void visitTableScan(TableScanNode node, Integer indent) { TableHandle table = node.getTable(); print(indent, "- TableScan[%s, originalConstraint = %s] => [%s]", table, node.getOriginalConstraint(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); printTableScanInfo(node, indent); @@ -735,6 +765,7 @@ public Void visitTableScan(TableScanNode node, Integer indent) public Void visitValues(ValuesNode node, Integer indent) { print(indent, "- Values => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); for (List row : node.getRows()) { print(indent + 2, "(" + Joiner.on(", ").join(row) + ")"); @@ -760,7 +791,8 @@ public Void visitProject(ProjectNode node, Integer indent) private Void visitScanFilterAndProjectInfo( PlanNodeId planNodeId, - Optional filterNode, Optional projectNode, + Optional filterNode, + Optional projectNode, int indent) { checkState(projectNode.isPresent() || filterNode.isPresent()); @@ -813,6 +845,11 @@ private Void visitScanFilterAndProjectInfo( format = operatorName + format; print(indent, format, arguments); + printCost(indent + 2, + Stream.of(scanNode, filterNode, projectNode) + .filter(Optional::isPresent) + .map(Optional::get) + .toArray(PlanNode[]::new)); printStats(indent + 2, planNodeId, true, true); if (projectNode.isPresent()) { @@ -876,6 +913,7 @@ private void printTableScanInfo(TableScanNode node, int indent) public Void visitUnnest(UnnestNode node, Integer indent) { print(indent, "- Unnest [replicate=%s, unnest=%s] => [%s]", formatOutputs(node.getReplicateSymbols()), formatOutputs(node.getUnnestSymbols().keySet()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -885,6 +923,7 @@ public Void visitUnnest(UnnestNode node, Integer indent) public Void visitOutput(OutputNode node, Integer indent) { print(indent, "- Output[%s] => [%s]", Joiner.on(", ").join(node.getColumnNames()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); @@ -903,6 +942,7 @@ public Void visitTopN(TopNNode node, Integer indent) Iterable keys = Iterables.transform(node.getOrderBy(), input -> input + " " + node.getOrderings().get(input)); print(indent, "- TopN[%s by (%s)] => [%s]", node.getCount(), Joiner.on(", ").join(keys), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -913,6 +953,7 @@ public Void visitSort(SortNode node, Integer indent) Iterable keys = Iterables.transform(node.getOrderBy(), input -> input + " " + node.getOrderings().get(input)); print(indent, "- Sort[%s] => [%s]", Joiner.on(", ").join(keys), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } @@ -921,6 +962,7 @@ public Void visitSort(SortNode node, Integer indent) public Void visitRemoteSource(RemoteSourceNode node, Integer indent) { print(indent, "- RemoteSource[%s] => [%s]", Joiner.on(',').join(node.getSourceFragmentIds()), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return null; @@ -930,6 +972,7 @@ public Void visitRemoteSource(RemoteSourceNode node, Integer indent) public Void visitUnion(UnionNode node, Integer indent) { print(indent, "- Union => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -939,6 +982,7 @@ public Void visitUnion(UnionNode node, Integer indent) public Void visitIntersect(IntersectNode node, Integer indent) { print(indent, "- Intersect => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -948,6 +992,7 @@ public Void visitIntersect(IntersectNode node, Integer indent) public Void visitExcept(ExceptNode node, Integer indent) { print(indent, "- Except => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -957,6 +1002,7 @@ public Void visitExcept(ExceptNode node, Integer indent) public Void visitTableWriter(TableWriterNode node, Integer indent) { print(indent, "- TableWriter => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); @@ -971,6 +1017,7 @@ public Void visitTableWriter(TableWriterNode node, Integer indent) public Void visitTableFinish(TableFinishNode node, Integer indent) { print(indent, "- TableCommit[%s] => [%s]", node.getTarget(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -980,6 +1027,7 @@ public Void visitTableFinish(TableFinishNode node, Integer indent) public Void visitSample(SampleNode node, Integer indent) { print(indent, "- Sample[%s: %s] => [%s]", node.getSampleType(), node.getSampleRatio(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -991,7 +1039,7 @@ public Void visitExchange(ExchangeNode node, Integer indent) if (node.getScope() == Scope.LOCAL) { print(indent, "- LocalExchange[%s%s]%s (%s) => %s", node.getPartitioningScheme().getPartitioning().getHandle(), - node.getPartitioningScheme().isReplicateNulls() ? " - REPLICATE NULLS" : "", + node.getPartitioningScheme().isReplicateNullsAndAny() ? " - REPLICATE NULLS AND ANY" : "", formatHash(node.getPartitioningScheme().getHashColumn()), Joiner.on(", ").join(node.getPartitioningScheme().getPartitioning().getArguments()), formatOutputs(node.getOutputSymbols())); @@ -1000,10 +1048,11 @@ public Void visitExchange(ExchangeNode node, Integer indent) print(indent, "- %sExchange[%s%s]%s => %s", UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, node.getScope().toString()), node.getType(), - node.getPartitioningScheme().isReplicateNulls() ? " - REPLICATE NULLS" : "", + node.getPartitioningScheme().isReplicateNullsAndAny() ? " - REPLICATE NULLS AND ANY" : "", formatHash(node.getPartitioningScheme().getHashColumn()), formatOutputs(node.getOutputSymbols())); } + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1013,6 +1062,7 @@ public Void visitExchange(ExchangeNode node, Integer indent) public Void visitDelete(DeleteNode node, Integer indent) { print(indent, "- Delete[%s] => [%s]", node.getTarget(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1022,6 +1072,7 @@ public Void visitDelete(DeleteNode node, Integer indent) public Void visitMetadataDelete(MetadataDeleteNode node, Integer indent) { print(indent, "- MetadataDelete[%s] => [%s]", node.getTarget(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1031,6 +1082,7 @@ public Void visitMetadataDelete(MetadataDeleteNode node, Integer indent) public Void visitEnforceSingleRow(EnforceSingleRowNode node, Integer indent) { print(indent, "- Scalar => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); @@ -1040,21 +1092,40 @@ public Void visitEnforceSingleRow(EnforceSingleRowNode node, Integer indent) public Void visitAssignUniqueId(AssignUniqueId node, Integer indent) { print(indent, "- AssignUniqueId => [%s]", formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); return processChildren(node, indent + 1); } + @Override + public Void visitGroupReference(GroupReference node, Integer indent) + { + print(indent, "- GroupReference[%s] => [%s]", node.getGroupId(), formatOutputs(node.getOutputSymbols())); + + return processChildren(node, indent + 1); + } + @Override public Void visitApply(ApplyNode node, Integer indent) { print(indent, "- Apply[%s] => [%s]", node.getCorrelation(), formatOutputs(node.getOutputSymbols())); + printCost(indent + 2, node); printStats(indent + 2, node.getId()); printAssignments(node.getSubqueryAssignments(), indent + 4); return processChildren(node, indent + 1); } + @Override + public Void visitLateralJoin(LateralJoinNode node, Integer indent) + { + print(indent, "- Lateral[%s] => [%s]", node.getCorrelation(), formatOutputs(node.getOutputSymbols())); + printStats(indent + 2, node.getId()); + + return processChildren(node, indent + 1); + } + @Override protected Void visitPlan(PlanNode node, Integer indent) { @@ -1149,6 +1220,31 @@ private String formatDomain(Domain domain) return "[" + Joiner.on(", ").join(parts.build()) + "]"; } + + private void printCost(int indent, PlanNode... nodes) + { + if (Arrays.stream(nodes).anyMatch(this::isKnownCost)) { + String costString = Joiner.on("/").join(Arrays.stream(nodes) + .map(this::formatCost) + .collect(toImmutableList())); + print(indent, "Cost: %s", costString); + } + } + + private boolean isKnownCost(PlanNode node) + { + return !UNKNOWN_COST.equals(costs.getOrDefault(node.getId(), UNKNOWN_COST)); + } + + private String formatCost(PlanNode node) + { + PlanNodeCost cost = costs.getOrDefault(node.getId(), UNKNOWN_COST); + Estimate outputRowCount = cost.getOutputRowCount(); + Estimate outputSizeInBytes = cost.getOutputSizeInBytes(); + return String.format("{rows: %s, bytes: %s}", + outputRowCount.isValueUnknown() ? "?" : String.valueOf((long) outputRowCount.getValue()), + outputSizeInBytes.isValueUnknown() ? "?" : succinctBytes((long) outputSizeInBytes.getValue())); + } } private static String formatHash(Optional... hashes) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java new file mode 100644 index 0000000000000..ff9d3923621f9 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoDuplicatePlanNodeIdsChecker.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.SimplePlanVisitor; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; + +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class NoDuplicatePlanNodeIdsChecker + implements PlanSanityChecker.Checker +{ + @Override + public void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) + { + planNode.accept(new Visitor(), new HashMap<>()); + } + + private static class Visitor + extends SimplePlanVisitor> + { + @Override + protected Void visitPlan(PlanNode node, Map context) + { + context.merge(node.getId(), node, this::reportDuplicateId); + + return super.visitPlan(node, context); + } + + private PlanNode reportDuplicateId(PlanNode first, PlanNode second) + { + requireNonNull(first, "first is null"); + requireNonNull(second, "second is null"); + checkArgument(first.getId().equals(second.getId())); + + throw new IllegalStateException(format( + "Generated plan contains nodes with duplicated id %s: %s and %s", + first.getId(), + first, + second)); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoApplyNodeLeftChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryRelatedNodeLeftChecker.java similarity index 66% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoApplyNodeLeftChecker.java rename to presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryRelatedNodeLeftChecker.java index 56e4a23d99d31..576c00e969743 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoApplyNodeLeftChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryRelatedNodeLeftChecker.java @@ -20,11 +20,13 @@ import com.facebook.presto.sql.planner.SimplePlanVisitor; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import java.util.List; import java.util.Map; -public class NoApplyNodeLeftChecker +public class NoSubqueryRelatedNodeLeftChecker implements PlanSanityChecker.Checker { @Override @@ -35,11 +37,22 @@ public void validate(PlanNode plan, Session session, Metadata metadata, SqlParse @Override public Object visitApply(ApplyNode node, Object context) { - if (node.getCorrelation().isEmpty()) { - throw new IllegalArgumentException("Unsupported subquery type"); + throw subqueryLeftException(node.getCorrelation()); + } + + @Override + public Object visitLateralJoin(LateralJoinNode node, Object context) + { + throw subqueryLeftException(node.getCorrelation()); + } + + private IllegalArgumentException subqueryLeftException(List correlation) + { + if (correlation.isEmpty()) { + return new IllegalArgumentException("Unsupported subquery type"); } else { - throw new IllegalArgumentException("Unsupported correlated subquery type"); + return new IllegalArgumentException("Unsupported correlated subquery type"); } } }, null); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java index 00865fe004790..3ed74647dbfaa 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/PlanSanityChecker.java @@ -19,33 +19,54 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Multimap; -import java.util.List; import java.util.Map; /** - * It is going to be executed at the end of logical planner, to verify its correctness + * It is going to be executed to verify logical planner correctness */ public final class PlanSanityChecker { - private static final List CHECKERS = ImmutableList.of( - new ValidateDependenciesChecker(), - new TypeValidator(), - new NoSubqueryExpressionLeftChecker(), - new NoApplyNodeLeftChecker(), - new VerifyNoFilteredAggregations(), - new VerifyOnlyOneOutputNode()); + private static final Multimap CHECKERS = ImmutableListMultimap.builder() + .putAll( + Stage.INTERMEDIATE, + new ValidateDependenciesChecker(), + new NoDuplicatePlanNodeIdsChecker(), + new TypeValidator(), + new NoSubqueryExpressionLeftChecker(), + new VerifyOnlyOneOutputNode()) + .putAll( + Stage.FINAL, + new ValidateDependenciesChecker(), + new NoDuplicatePlanNodeIdsChecker(), + new TypeValidator(), + new NoSubqueryExpressionLeftChecker(), + new VerifyOnlyOneOutputNode(), + new NoSubqueryRelatedNodeLeftChecker(), + new VerifyNoFilteredAggregations()) + .build(); private PlanSanityChecker() {} - public static void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) + public static void validateFinalPlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) { - CHECKERS.forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types)); + CHECKERS.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types)); + } + + public static void validateIntermediatePlan(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types) + { + CHECKERS.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, sqlParser, types)); } public interface Checker { void validate(PlanNode planNode, Session session, Metadata metadata, SqlParser sqlParser, Map types); } + + private enum Stage + { + INTERMEDIATE, FINAL + }; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java index 0790285791619..9bf66bf390af3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java @@ -23,12 +23,14 @@ import com.facebook.presto.sql.planner.SimplePlanVisitor; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ListMultimap; @@ -80,11 +82,11 @@ public Void visitAggregation(AggregationNode node, Void context) switch (step) { case SINGLE: - checkFunctionSignature(node.getFunctions()); + checkFunctionSignature(node.getAggregations()); checkFunctionCall(node.getAggregations()); break; case FINAL: - checkFunctionSignature(node.getFunctions()); + checkFunctionSignature(node.getAggregations()); break; } @@ -113,7 +115,8 @@ public Void visitProject(ProjectNode node, Void context) verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); continue; } - Type actualType = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList() /* parameters already replaced */).get(entry.getValue()); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList() /* parameters already replaced */); + Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } @@ -158,21 +161,22 @@ private void checkSignature(Symbol symbol, Signature signature) private void checkCall(Symbol symbol, FunctionCall call) { Type expectedType = types.get(symbol); - Type actualType = getExpressionTypes(session, metadata, sqlParser, types, call, emptyList() /*parameters already replaced */).get(call); + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, call, emptyList() /*parameters already replaced */); + Type actualType = expressionTypes.get(NodeRef.of(call)); verifyTypeSignature(symbol, expectedType.getTypeSignature(), actualType.getTypeSignature()); } - private void checkFunctionSignature(Map functions) + private void checkFunctionSignature(Map aggregations) { - for (Map.Entry entry : functions.entrySet()) { - checkSignature(entry.getKey(), entry.getValue()); + for (Map.Entry entry : aggregations.entrySet()) { + checkSignature(entry.getKey(), entry.getValue().getSignature()); } } - private void checkFunctionCall(Map functionCalls) + private void checkFunctionCall(Map aggregations) { - for (Map.Entry entry : functionCalls.entrySet()) { - checkCall(entry.getKey(), entry.getValue()); + for (Map.Entry entry : aggregations.entrySet()) { + checkCall(entry.getKey(), entry.getValue().getCall()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 21ea2802d3f50..c75a1e9d3c460 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -17,9 +17,10 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.DeleteNode; @@ -34,12 +35,12 @@ import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; @@ -58,12 +59,10 @@ import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -92,10 +91,8 @@ public static void validate(PlanNode plan) } private static class Visitor - extends PlanVisitor, Void> + extends PlanVisitor> { - private final Map nodesById = new HashMap<>(); - @Override protected Void visitPlan(PlanNode node, Set boundSymbols) { @@ -108,8 +105,6 @@ public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundSymbol PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -119,13 +114,11 @@ public Void visitAggregation(AggregationNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getGroupingKeys(), "Invalid node. Grouping key symbols (%s) not in source plan output (%s)", node.getGroupingKeys(), node.getSource().getOutputSymbols()); - for (FunctionCall call : node.getAggregations().values()) { - Set dependencies = DependencyExtractor.extractUnique(call); + for (Aggregation aggregation : node.getAggregations().values()) { + Set dependencies = SymbolsExtractor.extractUnique(aggregation.getCall()); checkDependencies(inputs, dependencies, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } @@ -138,8 +131,6 @@ public Void visitGroupId(GroupIdNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - checkDependencies(source.getOutputSymbols(), node.getInputSymbols(), "Invalid node. Grouping symbols (%s) not in source plan output (%s)", node.getInputSymbols(), source.getOutputSymbols()); return null; @@ -151,8 +142,6 @@ public Void visitMarkDistinct(MarkDistinctNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - checkDependencies(source.getOutputSymbols(), node.getDistinctSymbols(), "Invalid node. Mark distinct symbols (%s) not in source plan output (%s)", node.getDistinctSymbols(), source.getOutputSymbols()); return null; @@ -164,8 +153,6 @@ public Void visitWindow(WindowNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); @@ -183,7 +170,7 @@ public Void visitWindow(WindowNode node, Set boundSymbols) checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputSymbols()); for (WindowNode.Function function : node.getWindowFunctions().values()) { - Set dependencies = DependencyExtractor.extractUnique(function.getFunctionCall()); + Set dependencies = SymbolsExtractor.extractUnique(function.getFunctionCall()); checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } @@ -196,8 +183,6 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); checkDependencies(inputs, node.getOrderBy(), "Invalid node. Order by symbols (%s) not in source plan output (%s)", node.getOrderBy(), node.getSource().getOutputSymbols()); @@ -211,8 +196,6 @@ public Void visitRowNumber(RowNumberNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - checkDependencies(source.getOutputSymbols(), node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); return null; @@ -224,12 +207,10 @@ public Void visitFilter(FilterNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); - Set dependencies = DependencyExtractor.extractUnique(node.getPredicate()); + Set dependencies = SymbolsExtractor.extractUnique(node.getPredicate()); checkDependencies(inputs, dependencies, "Invalid node. Predicate dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); return null; @@ -241,8 +222,6 @@ public Void visitSample(SampleNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -252,11 +231,9 @@ public Void visitProject(ProjectNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); for (Expression expression : node.getAssignments().getExpressions()) { - Set dependencies = DependencyExtractor.extractUnique(expression); + Set dependencies = SymbolsExtractor.extractUnique(expression); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -269,8 +246,6 @@ public Void visitTopN(TopNNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); checkDependencies(inputs, node.getOrderBy(), @@ -287,8 +262,6 @@ public Void visitSort(SortNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - Set inputs = createInputs(source, boundSymbols); checkDependencies(inputs, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols()); checkDependencies(inputs, node.getOrderBy(), "Invalid node. Order by dependencies (%s) not in source plan output (%s)", node.getOrderBy(), node.getSource().getOutputSymbols()); @@ -302,8 +275,6 @@ public Void visitOutput(OutputNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - checkDependencies(source.getOutputSymbols(), node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), source.getOutputSymbols()); return null; @@ -315,8 +286,6 @@ public Void visitLimit(LimitNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -326,7 +295,8 @@ public Void visitDistinctLimit(DistinctLimitNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); + checkDependencies(source.getOutputSymbols(), node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), source.getOutputSymbols()); + return null; } @@ -336,8 +306,6 @@ public Void visitJoin(JoinNode node, Set boundSymbols) node.getLeft().accept(this, boundSymbols); node.getRight().accept(this, boundSymbols); - verifyUniqueId(node); - Set leftInputs = createInputs(node.getLeft(), boundSymbols); Set rightInputs = createInputs(node.getRight(), boundSymbols); Set allInputs = ImmutableSet.builder() @@ -351,7 +319,7 @@ public Void visitJoin(JoinNode node, Set boundSymbols) } node.getFilter().ifPresent(predicate -> { - Set predicateSymbols = DependencyExtractor.extractUnique(predicate); + Set predicateSymbols = SymbolsExtractor.extractUnique(predicate); checkArgument( allInputs.containsAll(predicateSymbols), "Symbol from filter (%s) not in sources (%s)", @@ -390,8 +358,6 @@ public Void visitSemiJoin(SemiJoinNode node, Set boundSymbols) node.getSource().accept(this, boundSymbols); node.getFilteringSource().accept(this, boundSymbols); - verifyUniqueId(node); - checkArgument(node.getSource().getOutputSymbols().contains(node.getSourceJoinSymbol()), "Symbol from semi join clause (%s) not in source (%s)", node.getSourceJoinSymbol(), node.getSource().getOutputSymbols()); checkArgument(node.getFilteringSource().getOutputSymbols().contains(node.getFilteringSourceJoinSymbol()), "Symbol from semi join clause (%s) not in filtering source (%s)", node.getSourceJoinSymbol(), node.getFilteringSource().getOutputSymbols()); @@ -411,8 +377,6 @@ public Void visitIndexJoin(IndexJoinNode node, Set boundSymbols) node.getProbeSource().accept(this, boundSymbols); node.getIndexSource().accept(this, boundSymbols); - verifyUniqueId(node); - Set probeInputs = createInputs(node.getProbeSource(), boundSymbols); Set indexSourceInputs = createInputs(node.getIndexSource(), boundSymbols); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { @@ -434,8 +398,6 @@ public Void visitIndexJoin(IndexJoinNode node, Set boundSymbols) @Override public Void visitIndexSource(IndexSourceNode node, Set boundSymbols) { - verifyUniqueId(node); - checkDependencies(node.getOutputSymbols(), node.getLookupSymbols(), "Lookup symbols must be part of output symbols"); checkDependencies(node.getAssignments().keySet(), node.getOutputSymbols(), "Assignments must contain mappings for output symbols"); @@ -445,8 +407,6 @@ public Void visitIndexSource(IndexSourceNode node, Set boundSymbols) @Override public Void visitTableScan(TableScanNode node, Set boundSymbols) { - verifyUniqueId(node); - checkArgument(node.getAssignments().keySet().containsAll(node.getOutputSymbols()), "Assignments must contain mappings for output symbols"); return null; @@ -455,7 +415,6 @@ public Void visitTableScan(TableScanNode node, Set boundSymbols) @Override public Void visitValues(ValuesNode node, Set boundSymbols) { - verifyUniqueId(node); return null; } @@ -465,8 +424,6 @@ public Void visitUnnest(UnnestNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); - verifyUniqueId(node); - Set required = ImmutableSet.builder() .addAll(node.getReplicateSymbols()) .addAll(node.getUnnestSymbols().keySet()) @@ -480,8 +437,6 @@ public Void visitUnnest(UnnestNode node, Set boundSymbols) @Override public Void visitRemoteSource(RemoteSourceNode node, Set boundSymbols) { - verifyUniqueId(node); - return null; } @@ -496,8 +451,6 @@ public Void visitExchange(ExchangeNode node, Set boundSymbols) checkDependencies(node.getOutputSymbols(), node.getPartitioningScheme().getOutputLayout(), "EXCHANGE must provide all of the necessary symbols for partition function"); - verifyUniqueId(node); - return null; } @@ -507,8 +460,6 @@ public Void visitTableWriter(TableWriterNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -518,8 +469,6 @@ public Void visitDelete(DeleteNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - verifyUniqueId(node); - checkArgument(source.getOutputSymbols().contains(node.getRowId()), "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", node.getRowId(), node.getSource().getOutputSymbols()); return null; @@ -528,8 +477,6 @@ public Void visitDelete(DeleteNode node, Set boundSymbols) @Override public Void visitMetadataDelete(MetadataDeleteNode node, Set boundSymbols) { - verifyUniqueId(node); - return null; } @@ -538,8 +485,6 @@ public Void visitTableFinish(TableFinishNode node, Set boundSymbols) { node.getSource().accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -557,8 +502,6 @@ private Void visitSetOperation(SetOperationNode node, Set boundSymbols) subplan.accept(this, boundSymbols); // visit child } - verifyUniqueId(node); - return null; } @@ -579,8 +522,6 @@ public Void visitEnforceSingleRow(EnforceSingleRowNode node, Set boundSy { node.getSource().accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -589,8 +530,6 @@ public Void visitAssignUniqueId(AssignUniqueId node, Set boundSymbols) { node.getSource().accept(this, boundSymbols); // visit child - verifyUniqueId(node); - return null; } @@ -606,7 +545,7 @@ public Void visitApply(ApplyNode node, Set boundSymbols) node.getSubquery().accept(this, subqueryCorrelation); // visit child checkDependencies(node.getInput().getOutputSymbols(), node.getCorrelation(), "APPLY input must provide all the necessary correlation symbols for subquery"); - checkDependencies(DependencyExtractor.extractUnique(node.getSubquery()), node.getCorrelation(), "not all APPLY correlation symbols are used in subquery"); + checkDependencies(SymbolsExtractor.extractUnique(node.getSubquery()), node.getCorrelation(), "not all APPLY correlation symbols are used in subquery"); ImmutableSet inputs = ImmutableSet.builder() .addAll(createInputs(node.getSubquery(), boundSymbols)) @@ -614,21 +553,34 @@ public Void visitApply(ApplyNode node, Set boundSymbols) .build(); for (Expression expression : node.getSubqueryAssignments().getExpressions()) { - Set dependencies = DependencyExtractor.extractUnique(expression); + Set dependencies = SymbolsExtractor.extractUnique(expression); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } - verifyUniqueId(node); - return null; } - private void verifyUniqueId(PlanNode node) + @Override + public Void visitLateralJoin(LateralJoinNode node, Set boundSymbols) { - PlanNodeId id = node.getId(); - checkArgument(!nodesById.containsKey(id), "Duplicate node id found %s between %s and %s", node.getId(), node, nodesById.get(id)); + Set subqueryCorrelation = ImmutableSet.builder() + .addAll(boundSymbols) + .addAll(node.getCorrelation()) + .build(); - nodesById.put(id, node); + node.getInput().accept(this, boundSymbols); // visit child + node.getSubquery().accept(this, subqueryCorrelation); // visit child + + checkDependencies( + node.getInput().getOutputSymbols(), + node.getCorrelation(), + "LATERAL input must provide all the necessary correlation symbols for subquery"); + checkDependencies( + SymbolsExtractor.extractUnique(node.getSubquery()), + node.getCorrelation(), + "not all LATERAL correlation symbols are used in subquery"); + + return null; } private static ImmutableSet createInputs(PlanNode source, Set boundSymbols) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/VerifyNoFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/VerifyNoFilteredAggregations.java index 42a1a715b91eb..253b0f7e2c103 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/VerifyNoFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/VerifyNoFilteredAggregations.java @@ -20,8 +20,8 @@ import com.facebook.presto.sql.planner.SimplePlanVisitor; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.tree.FunctionCall; import java.util.Map; @@ -42,8 +42,8 @@ public Void visitAggregation(AggregationNode node, Void context) { super.visitAggregation(node, context); - for (FunctionCall call : node.getAggregations().values()) { - if (call.getFilter().isPresent()) { + for (Aggregation aggregation : node.getAggregations().values()) { + if (aggregation.getCall().getFilter().isPresent()) { throw new IllegalStateException("Generated plan contains unimplemented filtered aggregations"); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/CallExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/CallExpression.java index 5d6307de32629..8bc0f2720b8c4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/CallExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/CallExpression.java @@ -83,7 +83,7 @@ public boolean equals(Object obj) } @Override - public R accept(RowExpressionVisitor visitor, C context) + public R accept(RowExpressionVisitor visitor, C context) { return visitor.visitCall(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/ConstantExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/ConstantExpression.java index 7c6eed8d10744..eceed74ff0b94 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/ConstantExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/ConstantExpression.java @@ -70,7 +70,7 @@ public boolean equals(Object obj) } @Override - public R accept(RowExpressionVisitor visitor, C context) + public R accept(RowExpressionVisitor visitor, C context) { return visitor.visitConstant(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/DeterminismEvaluator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/DeterminismEvaluator.java index 203720d6b5373..e1f9534b55e2d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/DeterminismEvaluator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/DeterminismEvaluator.java @@ -33,7 +33,7 @@ public boolean isDeterministic(RowExpression expression) } private static class Visitor - implements RowExpressionVisitor + implements RowExpressionVisitor { private final FunctionRegistry registry; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/InputReferenceExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/InputReferenceExpression.java index bdfe3502f6e62..f6e09ea817408 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/InputReferenceExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/InputReferenceExpression.java @@ -60,7 +60,7 @@ public String toString() } @Override - public R accept(RowExpressionVisitor visitor, C context) + public R accept(RowExpressionVisitor visitor, C context) { return visitor.visitInputReference(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/LambdaDefinitionExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/LambdaDefinitionExpression.java index 4d0246660f5ac..39b2439de31a6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/LambdaDefinitionExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/LambdaDefinitionExpression.java @@ -63,7 +63,7 @@ public Type getType() @Override public String toString() { - return "(" + Joiner.on("").join(arguments) + ") -> " + body; + return "(" + Joiner.on(",").join(arguments) + ") -> " + body; } @Override @@ -88,7 +88,7 @@ public int hashCode() } @Override - public R accept(RowExpressionVisitor visitor, C context) + public R accept(RowExpressionVisitor visitor, C context) { return visitor.visitLambda(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java index 4a03a9eb3c4c0..8d0e5a7d48759 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpression.java @@ -27,5 +27,5 @@ public abstract class RowExpression @Override public abstract String toString(); - public abstract R accept(RowExpressionVisitor visitor, C context); + public abstract R accept(RowExpressionVisitor visitor, C context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionVisitor.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionVisitor.java index 91edb5577e378..c1b9eee9c74c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionVisitor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionVisitor.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.relational; -public interface RowExpressionVisitor +public interface RowExpressionVisitor { R visitCall(CallExpression call, C context); R visitInputReference(InputReferenceExpression reference, C context); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java index 527fe60329ca1..10ff4b3250e7d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/Signatures.java @@ -16,6 +16,7 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; @@ -23,7 +24,6 @@ import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.type.LikePatternType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -155,9 +155,14 @@ public static Signature trySignature(Type returnType) return new Signature(TRY, SCALAR, returnType.getTypeSignature()); } - public static Signature bindSignature(Type returnType, Type valueType, Type functionType) + public static Signature bindSignature(Type returnType, List valueTypes, Type functionType) { - return new Signature(BIND, SCALAR, returnType.getTypeSignature(), valueType.getTypeSignature(), functionType.getTypeSignature()); + ImmutableList.Builder typeSignatureBuilder = ImmutableList.builder(); + for (Type valueType : valueTypes) { + typeSignatureBuilder.add(valueType.getTypeSignature()); + } + typeSignatureBuilder.add(functionType.getTypeSignature()); + return new Signature(BIND, SCALAR, returnType.getTypeSignature(), typeSignatureBuilder.build()); } // **************** functions that require varargs and/or complex types (e.g., lists) **************** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java index 5b18f9d3c3baa..9968d5ec16f1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlToRowExpressionTranslator.java @@ -19,6 +19,8 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.DecimalParseResult; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.RowType; +import com.facebook.presto.spi.type.RowType.RowField; import com.facebook.presto.spi.type.TimeZoneKey; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -54,6 +56,7 @@ import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullLiteral; @@ -67,14 +70,13 @@ import com.facebook.presto.sql.tree.TimestampLiteral; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; -import com.facebook.presto.type.RowType; -import com.facebook.presto.type.RowType.RowField; import com.facebook.presto.type.UnknownType; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import java.util.List; +import java.util.Map; import static com.facebook.presto.metadata.FunctionKind.SCALAR; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -132,7 +134,7 @@ private SqlToRowExpressionTranslator() {} public static RowExpression translate( Expression expression, FunctionKind functionKind, - IdentityLinkedHashMap types, + Map, Type> types, FunctionRegistry functionRegistry, TypeManager typeManager, Session session, @@ -154,18 +156,23 @@ private static class Visitor extends AstVisitor { private final FunctionKind functionKind; - private final IdentityLinkedHashMap types; + private final Map, Type> types; private final TypeManager typeManager; private final TimeZoneKey timeZoneKey; - private Visitor(FunctionKind functionKind, IdentityLinkedHashMap types, TypeManager typeManager, TimeZoneKey timeZoneKey) + private Visitor(FunctionKind functionKind, Map, Type> types, TypeManager typeManager, TimeZoneKey timeZoneKey) { this.functionKind = functionKind; - this.types = types; + this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null")); this.typeManager = typeManager; this.timeZoneKey = timeZoneKey; } + private Type getType(Expression node) + { + return types.get(NodeRef.of(node)); + } + @Override protected RowExpression visitExpression(Expression node, Void context) { @@ -175,7 +182,7 @@ protected RowExpression visitExpression(Expression node, Void context) @Override protected RowExpression visitFieldReference(FieldReference node, Void context) { - return field(node.getFieldIndex(), types.get(node)); + return field(node.getFieldIndex(), getType(node)); } @Override @@ -240,14 +247,14 @@ protected RowExpression visitGenericLiteral(GenericLiteral node, Void context) if (JSON.equals(type)) { return call( - new Signature("json_parse", SCALAR, types.get(node).getTypeSignature(), VARCHAR.getTypeSignature()), - types.get(node), + new Signature("json_parse", SCALAR, getType(node).getTypeSignature(), VARCHAR.getTypeSignature()), + getType(node), constant(utf8Slice(node.getValue()), VARCHAR)); } return call( - castSignature(types.get(node), VARCHAR), - types.get(node), + castSignature(getType(node), VARCHAR), + getType(node), constant(utf8Slice(node.getValue()), VARCHAR)); } @@ -255,28 +262,28 @@ protected RowExpression visitGenericLiteral(GenericLiteral node, Void context) protected RowExpression visitTimeLiteral(TimeLiteral node, Void context) { long value; - if (types.get(node).equals(TIME_WITH_TIME_ZONE)) { + if (getType(node).equals(TIME_WITH_TIME_ZONE)) { value = parseTimeWithTimeZone(node.getValue()); } else { // parse in time zone of client value = parseTimeWithoutTimeZone(timeZoneKey, node.getValue()); } - return constant(value, types.get(node)); + return constant(value, getType(node)); } @Override protected RowExpression visitTimestampLiteral(TimestampLiteral node, Void context) { long value; - if (types.get(node).equals(TIMESTAMP_WITH_TIME_ZONE)) { + if (getType(node).equals(TIMESTAMP_WITH_TIME_ZONE)) { value = parseTimestampWithTimeZone(timeZoneKey, node.getValue()); } else { // parse in time zone of client value = parseTimestampWithoutTimeZone(timeZoneKey, node.getValue()); } - return constant(value, types.get(node)); + return constant(value, getType(node)); } @Override @@ -289,7 +296,7 @@ protected RowExpression visitIntervalLiteral(IntervalLiteral node, Void context) else { value = node.getSign().multiplier() * parseDayTimeInterval(node.getValue(), node.getStartField(), node.getEndField()); } - return constant(value, types.get(node)); + return constant(value, getType(node)); } @Override @@ -317,15 +324,15 @@ protected RowExpression visitFunctionCall(FunctionCall node, Void context) .map(Type::getTypeSignature) .collect(toImmutableList()); - Signature signature = new Signature(node.getName().getSuffix(), functionKind, types.get(node).getTypeSignature(), argumentTypes); + Signature signature = new Signature(node.getName().getSuffix(), functionKind, getType(node).getTypeSignature(), argumentTypes); - return call(signature, types.get(node), arguments); + return call(signature, getType(node), arguments); } @Override protected RowExpression visitSymbolReference(SymbolReference node, Void context) { - return new VariableReferenceExpression(node.getName(), types.get(node)); + return new VariableReferenceExpression(node.getName(), getType(node)); } @Override @@ -333,7 +340,7 @@ protected RowExpression visitLambdaExpression(LambdaExpression node, Void contex { RowExpression body = process(node.getBody(), context); - Type type = types.get(node); + Type type = getType(node); List typeParameters = type.getTypeParameters(); List argumentTypes = typeParameters.subList(0, typeParameters.size() - 1); List argumentNames = node.getArguments().stream() @@ -346,14 +353,20 @@ protected RowExpression visitLambdaExpression(LambdaExpression node, Void contex @Override protected RowExpression visitBindExpression(BindExpression node, Void context) { - RowExpression value = process(node.getValue(), context); + ImmutableList.Builder valueTypesBuilder = ImmutableList.builder(); + ImmutableList.Builder argumentsBuilder = ImmutableList.builder(); + for (Expression value : node.getValues()) { + RowExpression valueRowExpression = process(value, context); + valueTypesBuilder.add(valueRowExpression.getType()); + argumentsBuilder.add(valueRowExpression); + } RowExpression function = process(node.getFunction(), context); + argumentsBuilder.add(function); return call( - bindSignature(types.get(node), value.getType(), function.getType()), - types.get(node), - value, - function); + bindSignature(getType(node), valueTypesBuilder.build(), function.getType()), + getType(node), + argumentsBuilder.build()); } @Override @@ -363,8 +376,8 @@ protected RowExpression visitArithmeticBinary(ArithmeticBinaryExpression node, V RowExpression right = process(node.getRight(), context); return call( - arithmeticExpressionSignature(node.getType(), types.get(node), left.getType(), right.getType()), - types.get(node), + arithmeticExpressionSignature(node.getType(), getType(node), left.getType(), right.getType()), + getType(node), left, right); } @@ -379,8 +392,8 @@ protected RowExpression visitArithmeticUnary(ArithmeticUnaryExpression node, Voi return expression; case MINUS: return call( - arithmeticNegationSignature(types.get(node), expression.getType()), - types.get(node), + arithmeticNegationSignature(getType(node), expression.getType()), + getType(node), expression); } @@ -403,14 +416,14 @@ protected RowExpression visitCast(Cast node, Void context) RowExpression value = process(node.getExpression(), context); if (node.isTypeOnly()) { - return changeType(value, types.get(node)); + return changeType(value, getType(node)); } if (node.isSafe()) { - return call(tryCastSignature(types.get(node), value.getType()), types.get(node), value); + return call(tryCastSignature(getType(node), value.getType()), getType(node), value); } - return call(castSignature(types.get(node), value.getType()), types.get(node), value); + return call(castSignature(getType(node), value.getType()), getType(node), value); } private static RowExpression changeType(RowExpression value, Type targetType) @@ -420,7 +433,7 @@ private static RowExpression changeType(RowExpression value, Type targetType) } private static class ChangeTypeVisitor - implements RowExpressionVisitor + implements RowExpressionVisitor { private final Type targetType; @@ -468,7 +481,7 @@ protected RowExpression visitCoalesceExpression(CoalesceExpression node, Void co .collect(toImmutableList()); List argumentTypes = arguments.stream().map(RowExpression::getType).collect(toImmutableList()); - return call(coalesceSignature(types.get(node), argumentTypes), types.get(node), arguments); + return call(coalesceSignature(getType(node), argumentTypes), getType(node), arguments); } @Override @@ -479,13 +492,13 @@ protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Voi arguments.add(process(node.getOperand(), context)); for (WhenClause clause : node.getWhenClauses()) { - arguments.add(call(whenSignature(types.get(clause)), - types.get(clause), + arguments.add(call(whenSignature(getType(clause)), + getType(clause), process(clause.getOperand(), context), process(clause.getResult(), context))); } - Type returnType = types.get(node); + Type returnType = getType(node); arguments.add(node.getDefaultValue() .map((value) -> process(value, context)) @@ -519,12 +532,12 @@ protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, */ RowExpression expression = node.getDefaultValue() .map((value) -> process(value, context)) - .orElse(constantNull(types.get(node))); + .orElse(constantNull(getType(node))); for (WhenClause clause : Lists.reverse(node.getWhenClauses())) { expression = call( - Signatures.ifSignature(types.get(node)), - types.get(node), + Signatures.ifSignature(getType(node)), + getType(node), process(clause.getOperand(), context), process(clause.getResult(), context), expression); @@ -536,7 +549,7 @@ protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, @Override protected RowExpression visitDereferenceExpression(DereferenceExpression node, Void context) { - RowType rowType = (RowType) types.get(node.getBase()); + RowType rowType = (RowType) getType(node.getBase()); List fields = rowType.getFields(); int index = -1; for (int i = 0; i < fields.size(); i++) { @@ -547,7 +560,7 @@ protected RowExpression visitDereferenceExpression(DereferenceExpression node, V } } checkState(index >= 0, "could not find field name: %s", node.getFieldName()); - Type returnType = types.get(node); + Type returnType = getType(node); return call(dereferenceSignature(returnType, rowType), returnType, process(node.getBase(), context), constant(index, INTEGER)); } @@ -563,16 +576,16 @@ protected RowExpression visitIfExpression(IfExpression node, Void context) arguments.add(process(node.getFalseValue().get(), context)); } else { - arguments.add(constantNull(types.get(node))); + arguments.add(constantNull(getType(node))); } - return call(Signatures.ifSignature(types.get(node)), types.get(node), arguments.build()); + return call(Signatures.ifSignature(getType(node)), getType(node), arguments.build()); } @Override protected RowExpression visitTryExpression(TryExpression node, Void context) { - return call(Signatures.trySignature(types.get(node)), types.get(node), process(node.getInnerExpression(), context)); + return call(Signatures.trySignature(getType(node)), getType(node), process(node.getInnerExpression(), context)); } @Override @@ -620,8 +633,8 @@ protected RowExpression visitNullIfExpression(NullIfExpression node, Void contex RowExpression second = process(node.getSecond(), context); return call( - nullIfSignature(types.get(node), first.getType(), second.getType()), - types.get(node), + nullIfSignature(getType(node), first.getType(), second.getType()), + getType(node), first, second); } @@ -662,8 +675,8 @@ protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void RowExpression index = process(node.getIndex(), context); return call( - subscriptSignature(types.get(node), base.getType(), index.getType()), - types.get(node), + subscriptSignature(getType(node), base.getType(), index.getType()), + getType(node), base, index); } @@ -677,7 +690,7 @@ protected RowExpression visitArrayConstructor(ArrayConstructor node, Void contex List argumentTypes = arguments.stream() .map(RowExpression::getType) .collect(toImmutableList()); - return call(arrayConstructorSignature(types.get(node), argumentTypes), types.get(node), arguments); + return call(arrayConstructorSignature(getType(node), argumentTypes), getType(node), arguments); } @Override @@ -686,9 +699,9 @@ protected RowExpression visitRow(Row node, Void context) List arguments = node.getItems().stream() .map(value -> process(value, context)) .collect(toImmutableList()); - Type returnType = types.get(node); + Type returnType = getType(node); List argumentTypes = node.getItems().stream() - .map(value -> types.get(value)) + .map(this::getType) .collect(toImmutableList()); return call(rowConstructorSignature(returnType, argumentTypes), returnType, arguments); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/VariableReferenceExpression.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/VariableReferenceExpression.java index 51787902b3e7c..46ed7a1a2bff4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/VariableReferenceExpression.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/VariableReferenceExpression.java @@ -55,7 +55,7 @@ public String toString() } @Override - public R accept(RowExpressionVisitor visitor, C context) + public R accept(RowExpressionVisitor visitor, C context) { return visitor.visitVariableReference(this, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java index 9e79c713d9777..15e5470e2ccd9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java @@ -72,7 +72,7 @@ public RowExpression optimize(RowExpression expression) } private class Visitor - implements RowExpressionVisitor + implements RowExpressionVisitor { @Override public RowExpression visitInputReference(InputReferenceExpression reference, Void context) @@ -129,15 +129,23 @@ public RowExpression visitCall(CallExpression call, Void context) return call(signature, call.getType(), arguments); } case BIND: { - checkState(call.getArguments().size() == 2, BIND + " function should have 2 arguments. Got " + call.getArguments().size()); - RowExpression optimizedValue = call.getArguments().get(0).accept(this, context); - RowExpression optimizedFunction = call.getArguments().get(1).accept(this, context); - if (optimizedValue instanceof ConstantExpression && optimizedFunction instanceof ConstantExpression) { - // Here, optimizedValue and optimizedFunction should be merged together into a new ConstantExpression. + checkState(call.getArguments().size() >= 1, BIND + " function should have at least 1 argument. Got " + call.getArguments().size()); + + boolean allConstantExpression = true; + ImmutableList.Builder optimizedArgumentsBuilder = ImmutableList.builder(); + for (RowExpression argument : call.getArguments()) { + RowExpression optimizedArgument = argument.accept(this, context); + if (!(optimizedArgument instanceof ConstantExpression)) { + allConstantExpression = false; + } + optimizedArgumentsBuilder.add(optimizedArgument); + } + if (allConstantExpression) { + // Here, optimizedArguments should be merged together into a new ConstantExpression. // It's not implemented because it would be dead code anyways because visitLambda does not produce ConstantExpression. throw new UnsupportedOperationException(); } - return call(signature, call.getType(), ImmutableList.of(optimizedValue, optimizedFunction)); + return call(signature, call.getType(), optimizedArgumentsBuilder.build()); } case NULL_IF: case SWITCH: diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java index 23227851d1161..56eb988648a3e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java @@ -143,7 +143,7 @@ public Statement rewrite( List parameters, AccessControl accessControl) { - return (Statement) new Visitor(metadata, parser, session, parameters, accessControl).process(node, null); + return (Statement) new Visitor(metadata, parser, session, parameters, accessControl, queryExplainer).process(node, null); } private static class Visitor @@ -154,14 +154,17 @@ private static class Visitor private final SqlParser sqlParser; List parameters; private final AccessControl accessControl; + private Optional queryExplainer; + + public Visitor(Metadata metadata, SqlParser sqlParser, Session session, List parameters, AccessControl accessControl, Optional queryExplainer) - public Visitor(Metadata metadata, SqlParser sqlParser, Session session, List parameters, AccessControl accessControl) { this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.session = requireNonNull(session, "session is null"); this.parameters = requireNonNull(parameters, "parameters is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); } @Override @@ -544,6 +547,7 @@ private static String getFunctionType(SqlFunction function) } @Override + protected Node visitShowSession(ShowSession node, Void context) { ImmutableList.Builder rows = ImmutableList.builder(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java new file mode 100644 index 0000000000000..c5e482abbd1f8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java @@ -0,0 +1,319 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.rewrite; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.QualifiedObjectName; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.security.AccessControl; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.ColumnStatistics; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.sql.QueryUtil; +import com.facebook.presto.sql.analyzer.QueryExplainer; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.tree.AllColumns; +import com.facebook.presto.sql.tree.AstVisitor; +import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.DoubleLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.IsNotNullPredicate; +import com.facebook.presto.sql.tree.IsNullPredicate; +import com.facebook.presto.sql.tree.Literal; +import com.facebook.presto.sql.tree.LogicalBinaryExpression; +import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.NotExpression; +import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.Query; +import com.facebook.presto.sql.tree.QuerySpecification; +import com.facebook.presto.sql.tree.Row; +import com.facebook.presto.sql.tree.SelectItem; +import com.facebook.presto.sql.tree.ShowStats; +import com.facebook.presto.sql.tree.SingleColumn; +import com.facebook.presto.sql.tree.Statement; +import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableSubquery; +import com.facebook.presto.sql.tree.Values; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeSet; + +import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; +import static com.facebook.presto.sql.QueryUtil.aliased; +import static com.facebook.presto.sql.QueryUtil.selectAll; +import static com.facebook.presto.sql.QueryUtil.simpleQuery; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; + +public class ShowStatsRewrite + implements StatementRewrite.Rewrite +{ + private static final List> ALLOWED_SHOW_STATS_WHERE_EXPRESSION_TYPES = ImmutableList.of( + Literal.class, Identifier.class, ComparisonExpression.class, LogicalBinaryExpression.class, NotExpression.class, IsNullPredicate.class, IsNotNullPredicate.class); + + @Override + public Statement rewrite(Session session, Metadata metadata, SqlParser parser, Optional queryExplainer, Statement node, List parameters, AccessControl accessControl) + { + return (Statement) new Visitor(metadata, session, parameters, queryExplainer).process(node, null); + } + + private static class Visitor + extends AstVisitor + { + private final Metadata metadata; + private final Session session; + private final List parameters; + private final Optional queryExplainer; + + public Visitor(Metadata metadata, Session session, List parameters, Optional queryExplainer) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); + this.parameters = requireNonNull(parameters, "parameters is null"); + this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); + } + + @Override + protected Node visitShowStats(ShowStats node, Void context) + { + validateShowStats(node); + checkState(queryExplainer.isPresent(), "Query explainer must be provided for SHOW STATS SELECT"); + + if (node.getRelation() instanceof TableSubquery) { + QuerySpecification specification = (QuerySpecification) ((TableSubquery) node.getRelation()).getQuery().getQueryBody(); + Table table = (Table) specification.getFrom().get(); + Constraint constraint = getConstraint(specification); + return rewriteShowStats(node, table, constraint); + } + else if (node.getRelation() instanceof Table) { + Table table = (Table) node.getRelation(); + return rewriteShowStats(node, table, Constraint.alwaysTrue()); + } + else { + throw new IllegalArgumentException("Expected either TableSubquery or Table as relation"); + } + } + + private void validateShowStats(ShowStats node) + { + // The following properties of SELECT subquery are required: + // - only one relation in FROM + // - only plain columns in projection + // - only plain columns and constants in WHERE + // - no group by + // - no having + // - no set quantifier + if (!(node.getRelation() instanceof Table || node.getRelation() instanceof TableSubquery)) { + throw new SemanticException(NOT_SUPPORTED, node, "Only table and simple table subquery can be passed as argument to SHOW STATS clause"); + } + + if (node.getRelation() instanceof TableSubquery) { + Query query = ((TableSubquery) node.getRelation()).getQuery(); + check(query.getQueryBody() instanceof QuerySpecification, node, "Only table and simple table subquery can be passed as argument to SHOW STATS clause"); + QuerySpecification querySpecification = (QuerySpecification) query.getQueryBody(); + + check(querySpecification.getFrom().isPresent(), node, "There must be exactly one table in query passed to SHOW STATS SELECT clause"); + check(querySpecification.getFrom().get() instanceof Table, node, "There must be exactly one table in query passed to SHOW STATS SELECT clause"); + check(!query.getWith().isPresent(), node, "WITH is not supported by SHOW STATS SELECT clause"); + check(!querySpecification.getOrderBy().isPresent(), node, "ORDER BY is not supported in SHOW STATS SELECT clause"); + check(!querySpecification.getLimit().isPresent(), node, "LIMIT is not supported by SHOW STATS SELECT clause"); + check(!querySpecification.getHaving().isPresent(), node, "HAVING is not supported in SHOW STATS SELECT clause"); + check(!querySpecification.getGroupBy().isPresent(), node, "GROUP BY is not supported in SHOW STATS SELECT clause"); + check(!querySpecification.getSelect().isDistinct(), node, "DISTINCT is not supported by SHOW STATS SELECT clause"); + for (SelectItem selectItem : querySpecification.getSelect().getSelectItems()) { + if (selectItem instanceof AllColumns) { + continue; + } + check(selectItem instanceof SingleColumn, node, "Only * and column references are supported by SHOW STATS SELECT clause"); + SingleColumn columnSelect = (SingleColumn) selectItem; + check(columnSelect.getExpression() instanceof Identifier, node, "Only * and column references are supported by SHOW STATS SELECT clause"); + } + + querySpecification.getWhere().ifPresent((expression) -> validateShowStatsWhereExpression(expression, node)); + } + } + + void validateShowStatsWhereExpression(Expression expression, ShowStats node) + { + check(ALLOWED_SHOW_STATS_WHERE_EXPRESSION_TYPES.stream().anyMatch(clazz -> clazz.isInstance(expression)), node, "Only literals, column references, comparators, is (not) null and logical operators are allowed in WHERE of SHOW STATS SELECT clause"); + + if (expression instanceof NotExpression) { + validateShowStatsWhereExpression(((NotExpression) expression).getValue(), node); + } + else if (expression instanceof LogicalBinaryExpression) { + validateShowStatsWhereExpression(((LogicalBinaryExpression) expression).getLeft(), node); + validateShowStatsWhereExpression(((LogicalBinaryExpression) expression).getRight(), node); + } + else if (expression instanceof ComparisonExpression) { + validateShowStatsWhereExpression(((ComparisonExpression) expression).getLeft(), node); + validateShowStatsWhereExpression(((ComparisonExpression) expression).getRight(), node); + } + else if (expression instanceof IsNullPredicate) { + validateShowStatsWhereExpression(((IsNullPredicate) expression).getValue(), node); + } + else if (expression instanceof IsNotNullPredicate) { + validateShowStatsWhereExpression(((IsNotNullPredicate) expression).getValue(), node); + } + } + + private Node rewriteShowStats(ShowStats node, Table table, Constraint constraint) + { + TableHandle tableHandle = getTableHandle(node, table.getName()); + TableStatistics tableStatistics = metadata.getTableStatistics(session, tableHandle, constraint); + List statisticsNames = findUniqueStatisticsNames(tableStatistics); + List resultColumnNames = buildColumnsNames(statisticsNames); + List selectItems = buildSelectItems(resultColumnNames); + Map columnNames = getStatisticsColumnNames(tableStatistics, node, table.getName()); + + List resultRows = buildStatisticsRows(tableStatistics, columnNames, statisticsNames); + + return simpleQuery(selectAll(selectItems), + aliased(new Values(resultRows), + "table_stats_for_" + table.getName(), + resultColumnNames)); + } + + private static void check(boolean condition, ShowStats node, String message) + { + if (!condition) { + throw new SemanticException(NOT_SUPPORTED, node, message); + } + } + + @Override + protected Node visitNode(Node node, Void context) + { + return node; + } + + private Constraint getConstraint(QuerySpecification specification) + { + if (!specification.getWhere().isPresent()) { + return Constraint.alwaysTrue(); + } + + Plan plan = queryExplainer.get().getLogicalPlan(session, new Query(Optional.empty(), specification, Optional.empty(), Optional.empty()), parameters); + + Optional scanNode = searchFrom(plan.getRoot()) + .where(TableScanNode.class::isInstance) + .findFirst(); + + if (!scanNode.isPresent()) { + return new Constraint<>(TupleDomain.none(), bindings -> true); + } + + return new Constraint<>(scanNode.get().getCurrentConstraint(), bindings -> true); + } + + private Map getStatisticsColumnNames(TableStatistics statistics, ShowStats node, QualifiedName tableName) + { + TableHandle tableHandle = getTableHandle(node, tableName); + + return statistics.getColumnStatistics() + .keySet().stream() + .collect(toMap(identity(), column -> metadata.getColumnMetadata(session, tableHandle, column).getName())); + } + + private TableHandle getTableHandle(ShowStats node, QualifiedName table) + { + QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, node, table); + return metadata.getTableHandle(session, qualifiedTableName) + .orElseThrow(() -> new SemanticException(MISSING_TABLE, node, "Table %s not found", table)); + } + + private static List findUniqueStatisticsNames(TableStatistics tableStatistics) + { + TreeSet statisticsKeys = new TreeSet<>(); + statisticsKeys.addAll(tableStatistics.getTableStatistics().keySet()); + for (ColumnStatistics columnStats : tableStatistics.getColumnStatistics().values()) { + statisticsKeys.addAll(columnStats.getStatistics().keySet()); + } + return unmodifiableList(new ArrayList(statisticsKeys)); + } + + static List buildStatisticsRows(TableStatistics tableStatistics, Map columnNames, List statisticsNames) + { + ImmutableList.Builder rowsBuilder = ImmutableList.builder(); + + // Stats for columns + for (Map.Entry columnStats : tableStatistics.getColumnStatistics().entrySet()) { + Map columnStatisticsValues = columnStats.getValue().getStatistics(); + rowsBuilder.add(createStatsRow(Optional.of(columnNames.get(columnStats.getKey())), statisticsNames, columnStatisticsValues)); + } + + // Stats for whole table + rowsBuilder.add(createStatsRow(Optional.empty(), statisticsNames, tableStatistics.getTableStatistics())); + + return rowsBuilder.build(); + } + + static List buildSelectItems(List columnNames) + { + return columnNames.stream().map(QueryUtil::unaliasedName).collect(toImmutableList()); + } + + static List buildColumnsNames(List statisticsNames) + { + ImmutableList.Builder columnNamesBuilder = ImmutableList.builder(); + columnNamesBuilder.add("column_name"); + columnNamesBuilder.addAll(statisticsNames); + return columnNamesBuilder.build(); + } + + private static Row createStatsRow(Optional columnName, List statisticsNames, Map columnStatisticsValues) + { + ImmutableList.Builder rowValues = ImmutableList.builder(); + Expression columnNameExpression = columnName.map(name -> (Expression) new StringLiteral(name)).orElse(new Cast(new NullLiteral(), VARCHAR)); + + rowValues.add(columnNameExpression); + for (String statName : statisticsNames) { + rowValues.add(createStatisticValueOrNull(columnStatisticsValues, statName)); + } + return new Row(rowValues.build()); + } + + private static Expression createStatisticValueOrNull(Map columnStatisticsValues, String statName) + { + if (columnStatisticsValues.containsKey(statName) && !columnStatisticsValues.get(statName).isValueUnknown()) { + return new DoubleLiteral(Double.toString(columnStatisticsValues.get(statName).getValue())); + } + else { + return new NullLiteral(); + } + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java index 1974936a48762..508d6217eebe2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java @@ -33,6 +33,7 @@ public final class StatementRewrite new DescribeInputRewrite(), new DescribeOutputRewrite(), new ShowQueriesRewrite(), + new ShowStatsRewrite(), new ExplainRewrite()); private StatementRewrite() {} diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 647a5d94dd50d..300e072d46c18 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -29,6 +29,8 @@ import com.facebook.presto.connector.system.SchemaPropertiesSystemTable; import com.facebook.presto.connector.system.TablePropertiesSystemTable; import com.facebook.presto.connector.system.TransactionsSystemTable; +import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.CommitTask; import com.facebook.presto.execution.CreateTableTask; import com.facebook.presto.execution.CreateViewTask; @@ -209,6 +211,7 @@ public class LocalQueryRunner private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final MetadataManager metadata; + private final CostCalculator costCalculator; private final TestingAccessControlManager accessControl; private final SplitManager splitManager; private final BlockEncodingSerde blockEncodingSerde; @@ -283,6 +286,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new SchemaPropertyManager(), new TablePropertyManager(), transactionManager); + this.costCalculator = new CoefficientBasedCostCalculator(metadata); this.accessControl = new TestingAccessControlManager(transactionManager); this.pageSourceManager = new PageSourceManager(); @@ -399,6 +403,12 @@ public Metadata getMetadata() return metadata; } + @Override + public CostCalculator getCostCalculator() + { + return costCalculator; + } + @Override public TestingAccessControlManager getAccessControl() { @@ -556,7 +566,7 @@ public List createDrivers(Session session, @Language("SQL") String sql, Plan plan = createPlan(session, sql); if (printPlan) { - System.out.println(PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, session)); + System.out.println(PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), metadata, costCalculator, session)); } SubPlan subplan = PlanFragmenter.createSubPlans(session, metadata, plan); @@ -567,6 +577,7 @@ public List createDrivers(Session session, @Language("SQL") String sql, LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( metadata, sqlParser, + costCalculator, Optional.empty(), pageSourceManager, indexManager, @@ -661,11 +672,15 @@ public Plan createPlan(Session session, @Language("SQL") String sql, LogicalPlan assertFormattedSql(sqlParser, statement); + return createPlan(session, sql, getPlanOptimizers(forceSingleNode), stage); + } + + public List getPlanOptimizers(boolean forceSingleNode) + { FeaturesConfig featuresConfig = new FeaturesConfig() .setDistributedIndexJoinsEnabled(false) .setOptimizeHashGeneration(true); - PlanOptimizers planOptimizers = new PlanOptimizers(metadata, sqlParser, featuresConfig, forceSingleNode, new MBeanExporter(new TestingMBeanServer())); - return createPlan(session, sql, planOptimizers.get(), stage); + return new PlanOptimizers(metadata, sqlParser, featuresConfig, forceSingleNode, new MBeanExporter(new TestingMBeanServer())).get(); } public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers) @@ -693,10 +708,11 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List getTableHandleClass() { - return TestingMetadata.InMemoryTableHandle.class; + return TestingTableHandle.class; } @Override @@ -40,7 +42,7 @@ public Class getTableLayoutHandleClass() @Override public Class getColumnHandleClass() { - return TestingMetadata.InMemoryColumnHandle.class; + return TestingColumnHandle.class; } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingMetadata.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingMetadata.java index fd5830990a3c7..0d5d02e7d6898 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingMetadata.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingMetadata.java @@ -34,6 +34,8 @@ import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.security.Privilege; import com.facebook.presto.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; @@ -43,7 +45,9 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; +import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -77,7 +81,7 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable if (!tables.containsKey(tableName)) { return null; } - return new InMemoryTableHandle(tableName); + return new TestingTableHandle(tableName); } @Override @@ -108,7 +112,7 @@ public Map getColumnHandles(ConnectorSession session, Conn ImmutableMap.Builder builder = ImmutableMap.builder(); int index = 0; for (ColumnMetadata columnMetadata : getTableMetadata(session, tableHandle).getColumns()) { - builder.put(columnMetadata.getName(), new InMemoryColumnHandle(columnMetadata.getName(), index, columnMetadata.getType())); + builder.put(columnMetadata.getName(), new TestingColumnHandle(columnMetadata.getName(), index, columnMetadata.getType())); index++; } return builder.build(); @@ -134,7 +138,7 @@ public Map> listTableColumns(ConnectorSess public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) { SchemaTableName tableName = getTableName(tableHandle); - int columnIndex = ((InMemoryColumnHandle) columnHandle).getOrdinalPosition(); + int columnIndex = ((TestingColumnHandle) columnHandle).getOrdinalPosition(); return tables.get(tableName).getColumns().get(columnIndex); } @@ -279,41 +283,63 @@ public void clear() private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) { requireNonNull(tableHandle, "tableHandle is null"); - checkArgument(tableHandle instanceof InMemoryTableHandle, "tableHandle is not an instance of InMemoryTableHandle"); - InMemoryTableHandle inMemoryTableHandle = (InMemoryTableHandle) tableHandle; - return inMemoryTableHandle.getTableName(); + checkArgument(tableHandle instanceof TestingTableHandle, "tableHandle is not an instance of TestingTableHandle"); + TestingTableHandle testingTableHandle = (TestingTableHandle) tableHandle; + return testingTableHandle.getTableName(); } - public static class InMemoryTableHandle + public static class TestingTableHandle implements ConnectorTableHandle { private final SchemaTableName tableName; - public InMemoryTableHandle(SchemaTableName schemaTableName) + public TestingTableHandle() { - this.tableName = schemaTableName; + this(new SchemaTableName("test-schema", "test-table")); } + @JsonCreator + public TestingTableHandle(@JsonProperty("tableName") SchemaTableName schemaTableName) + { + this.tableName = requireNonNull(schemaTableName, "schemaTableName is null"); + } + + @JsonProperty public SchemaTableName getTableName() { return tableName; } } - public static class InMemoryColumnHandle + public static class TestingColumnHandle implements ColumnHandle { private final String name; - private final int ordinalPosition; - private final Type type; + private final OptionalInt ordinalPosition; + private final Optional type; - public InMemoryColumnHandle(String name, int ordinalPosition, Type type) + public TestingColumnHandle(String name) { - this.name = name; - this.ordinalPosition = ordinalPosition; - this.type = type; + this(name, OptionalInt.empty(), Optional.empty()); } + public TestingColumnHandle(String name, int ordinalPosition, Type type) + { + this(name, OptionalInt.of(ordinalPosition), Optional.of(type)); + } + + @JsonCreator + public TestingColumnHandle( + @JsonProperty("name") String name, + @JsonProperty("ordinalPosition") OptionalInt ordinalPosition, + @JsonProperty("type") Optional type) + { + this.name = requireNonNull(name, "name is null"); + this.ordinalPosition = requireNonNull(ordinalPosition, "ordinalPosition is null"); + this.type = requireNonNull(type, "type is null"); + } + + @JsonProperty public String getName() { return name; @@ -321,12 +347,45 @@ public String getName() public int getOrdinalPosition() { - return ordinalPosition; + return ordinalPosition.orElseThrow(() -> new UnsupportedOperationException("Testing handle was created without ordinal position")); } public Type getType() + { + return type.orElseThrow(() -> new UnsupportedOperationException("Testing handle was created without type")); + } + + @JsonProperty("ordinalPosition") + public OptionalInt getJsonOrdinalPosition() + { + return ordinalPosition; + } + + @JsonProperty("type") + public Optional getJsonType() { return type; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestingColumnHandle that = (TestingColumnHandle) o; + return Objects.equals(name, that.name) && + Objects.equals(ordinalPosition, that.ordinalPosition) && + Objects.equals(type, that.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, ordinalPosition, type); + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java index 4a7da46c6acfa..6f29a974ba0d9 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/ArrayParametricType.java @@ -13,10 +13,12 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import java.util.List; @@ -39,7 +41,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { checkArgument(parameters.size() == 1, "Array type expects exactly one type as a parameter, got %s", parameters); checkArgument( diff --git a/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java b/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java index 1a6f1371911eb..81a2f642400d1 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java +++ b/presto-main/src/main/java/com/facebook/presto/type/BooleanOperators.java @@ -32,6 +32,7 @@ import static com.facebook.presto.spi.function.OperatorType.LESS_THAN; import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; +import static java.lang.Float.floatToRawIntBits; import static java.nio.charset.StandardCharsets.US_ASCII; public final class BooleanOperators @@ -99,6 +100,13 @@ public static double castToDouble(@SqlType(StandardTypes.BOOLEAN) boolean value) return value ? 1 : 0; } + @ScalarOperator(CAST) + @SqlType(StandardTypes.REAL) + public static long castToReal(@SqlType(StandardTypes.BOOLEAN) boolean value) + { + return value ? floatToRawIntBits(1.0f) : floatToRawIntBits(0.0f); + } + @ScalarOperator(CAST) @SqlType(StandardTypes.BIGINT) public static long castToBigint(@SqlType(StandardTypes.BOOLEAN) boolean value) diff --git a/presto-main/src/main/java/com/facebook/presto/type/CharParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/CharParametricType.java index 44a37e9f9dedc..fc489d55f6105 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/CharParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/CharParametricType.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import java.util.List; @@ -34,7 +35,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { if (parameters.isEmpty()) { return createCharType(1); diff --git a/presto-main/src/main/java/com/facebook/presto/type/DecimalParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/DecimalParametricType.java index c95c6aca88a86..276ab109f7163 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/DecimalParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/DecimalParametricType.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import java.util.List; @@ -33,7 +34,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { switch (parameters.size()) { case 0: diff --git a/presto-main/src/main/java/com/facebook/presto/type/FunctionParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/FunctionParametricType.java index 61ba4ddc2cea9..cc0db90837dba 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/FunctionParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/FunctionParametricType.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import java.util.List; @@ -40,7 +41,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { checkArgument(parameters.size() >= 1, "Function type must have at least one parameter, got %s", parameters); checkArgument( diff --git a/presto-main/src/main/java/com/facebook/presto/type/FunctionType.java b/presto-main/src/main/java/com/facebook/presto/type/FunctionType.java index 04214cbd3cc37..a180b772de56e 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/FunctionType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/FunctionType.java @@ -17,30 +17,30 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.type.AbstractType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; -import java.lang.invoke.MethodHandle; import java.util.List; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class FunctionType - extends AbstractType + implements Type { public static final String NAME = "function"; + private final TypeSignature signature; private final Type returnType; private final List argumentTypes; public FunctionType(List argumentTypes, Type returnType) { - super(new TypeSignature(NAME, typeParameters(argumentTypes, returnType)), MethodHandle.class); + this.signature = new TypeSignature(NAME, typeParameters(argumentTypes, returnType)); this.returnType = requireNonNull(returnType, "returnType is null"); this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); } @@ -69,41 +69,149 @@ public List getArgumentTypes() } @Override - public Object getObjectValue(ConnectorSession session, Block block, int position) + public List getTypeParameters() { - throw new UnsupportedOperationException(); + return ImmutableList.builder().addAll(argumentTypes).add(returnType).build(); } @Override - public void appendTo(Block block, int position, BlockBuilder blockBuilder) + public final TypeSignature getTypeSignature() { - throw new UnsupportedOperationException(); + return signature; } @Override - public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public String getDisplayName() + { + ImmutableList names = getTypeParameters().stream() + .map(Type::getDisplayName) + .collect(toImmutableList()); + return "function<" + Joiner.on(",").join(names) + ">"; + } + + @Override + public final Class getJavaType() + { + throw new UnsupportedOperationException(getTypeSignature() + " type does not have Java type"); + } + + @Override + public boolean isComparable() + { + return false; + } + + @Override + public boolean isOrderable() + { + return false; + } + + @Override + public long hash(Block block, int position) + { + throw new UnsupportedOperationException(getTypeSignature() + " type is not comparable"); + } + + @Override + public boolean equalTo(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) + { + throw new UnsupportedOperationException(getTypeSignature() + " type is not comparable"); + } + + @Override + public int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) + { + throw new UnsupportedOperationException(getTypeSignature() + " type is not orderable"); + } + + @Override + public boolean getBoolean(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeBoolean(BlockBuilder blockBuilder, boolean value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public long getLong(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeLong(BlockBuilder blockBuilder, long value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public double getDouble(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeDouble(BlockBuilder blockBuilder, double value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Slice getSlice(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObject(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public Object getObjectValue(ConnectorSession session, Block block, int position) { throw new UnsupportedOperationException(); } @Override - public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public void appendTo(Block block, int position, BlockBuilder blockBuilder) { throw new UnsupportedOperationException(); } @Override - public List getTypeParameters() + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { - return ImmutableList.builder().addAll(argumentTypes).add(returnType).build(); + throw new UnsupportedOperationException(); } @Override - public String getDisplayName() + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) { - ImmutableList names = getTypeParameters().stream() - .map(Type::getDisplayName) - .collect(toImmutableList()); - return "function<" + Joiner.on(",").join(names) + ">"; + throw new UnsupportedOperationException(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/ListLiteralType.java b/presto-main/src/main/java/com/facebook/presto/type/ListLiteralType.java new file mode 100644 index 0000000000000..1e2e7740eece2 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/type/ListLiteralType.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.type; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.AbstractType; +import com.facebook.presto.spi.type.TypeSignature; + +import java.util.List; + +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; + +/** + * This type is used to store literals in a Java list to avoid needless serialization and + * deserialization of blocks for function arguments which stay constant. + */ +public class ListLiteralType + extends AbstractType +{ + public static final ListLiteralType LIST_LITERAL = new ListLiteralType(); + public static final String NAME = "ListLiteral"; + + public ListLiteralType() + { + super(new TypeSignature(NAME), List.class); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "ListLiteral type cannot be serialized"); + } + + @Override + public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "ListLiteral type cannot be serialized"); + } + + @Override + public Object getObjectValue(ConnectorSession session, Block block, int position) + { + throw new UnsupportedOperationException(); + } + + @Override + public void appendTo(Block block, int position, BlockBuilder blockBuilder) + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java index 05cc2e9b6bbf9..9e7a3635b2244 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/MapParametricType.java @@ -13,23 +13,31 @@ */ package com.facebook.presto.type; +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; +import com.google.common.collect.ImmutableList; +import java.lang.invoke.MethodHandle; import java.util.List; +import static com.facebook.presto.spi.block.MethodHandleUtil.compose; +import static com.facebook.presto.spi.block.MethodHandleUtil.nativeValueGetter; import static com.google.common.base.Preconditions.checkArgument; public final class MapParametricType implements ParametricType { - public static final MapParametricType MAP = new MapParametricType(); + private final boolean useNewMapBlock; - private MapParametricType() + public MapParametricType(boolean useNewMapBlock) { + this.useNewMapBlock = useNewMapBlock; } @Override @@ -39,7 +47,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { checkArgument(parameters.size() == 2, "Expected two parameters, got %s", parameters); TypeParameter firstParameter = parameters.get(0); @@ -48,6 +56,19 @@ public Type createType(List parameters) firstParameter.getKind() == ParameterKind.TYPE && secondParameter.getKind() == ParameterKind.TYPE, "Expected key and type to be types, got %s", parameters); - return new MapType(firstParameter.getType(), secondParameter.getType()); + + Type keyType = firstParameter.getType(); + Type valueType = secondParameter.getType(); + MethodHandle keyNativeEquals = typeManager.resolveOperator(OperatorType.EQUAL, ImmutableList.of(keyType, keyType)); + MethodHandle keyBlockNativeEquals = compose(keyNativeEquals, nativeValueGetter(keyType)); + MethodHandle keyNativeHashCode = typeManager.resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(keyType)); + MethodHandle keyBlockHashCode = compose(keyNativeHashCode, nativeValueGetter(keyType)); + return new MapType( + useNewMapBlock, + keyType, + valueType, + useNewMapBlock ? keyBlockNativeEquals : null, + useNewMapBlock ? keyNativeHashCode : null, + useNewMapBlock ? keyBlockHashCode : null); } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java index 49cab4bde124f..173983ee6fdec 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/RowParametricType.java @@ -16,8 +16,10 @@ import com.facebook.presto.spi.type.NamedType; import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import java.util.List; @@ -42,7 +44,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { checkArgument(!parameters.isEmpty(), "Row type must have at least one parameter"); checkArgument( diff --git a/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java b/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java index 65db2b14ae088..cc6484979806c 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/type/TypeJsonUtils.java @@ -18,9 +18,12 @@ import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; import com.facebook.presto.spi.type.FixedWidthType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; @@ -149,7 +152,12 @@ else if (type.getJavaType() == boolean.class) { type.writeBoolean(blockBuilder, parser.getBooleanValue()); } else if (type.getJavaType() == long.class) { - type.writeLong(blockBuilder, parser.getLongValue()); + if (type.equals(REAL)) { + type.writeLong(blockBuilder, floatToRawIntBits(parser.getFloatValue())); + } + else { + type.writeLong(blockBuilder, parser.getLongValue()); + } } else if (type.getJavaType() == double.class) { type.writeDouble(blockBuilder, getDoubleValue(parser)); @@ -222,6 +230,7 @@ public static boolean canCastFromJson(Type type) baseType.equals(StandardTypes.INTEGER) || baseType.equals(StandardTypes.BIGINT) || baseType.equals(StandardTypes.DOUBLE) || + baseType.equals(StandardTypes.REAL) || baseType.equals(StandardTypes.VARCHAR) || baseType.equals(StandardTypes.DECIMAL) || baseType.equals(StandardTypes.JSON)) { @@ -245,6 +254,7 @@ private static boolean isValidJsonObjectKeyType(Type type) baseType.equals(StandardTypes.INTEGER) || baseType.equals(StandardTypes.BIGINT) || baseType.equals(StandardTypes.DOUBLE) || + baseType.equals(StandardTypes.REAL) || baseType.equals(StandardTypes.DECIMAL) || baseType.equals(StandardTypes.VARCHAR); } diff --git a/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java b/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java index 09a330617a410..4c6ba635494bb 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/type/TypeRegistry.java @@ -13,8 +13,12 @@ */ package com.facebook.presto.type; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; @@ -23,12 +27,15 @@ import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import javax.annotation.concurrent.ThreadSafe; import javax.inject.Inject; +import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -67,7 +74,7 @@ import static com.facebook.presto.type.JsonPathType.JSON_PATH; import static com.facebook.presto.type.JsonType.JSON; import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; -import static com.facebook.presto.type.MapParametricType.MAP; +import static com.facebook.presto.type.ListLiteralType.LIST_LITERAL; import static com.facebook.presto.type.Re2JRegexpType.RE2J_REGEXP; import static com.facebook.presto.type.RowParametricType.ROW; import static com.facebook.presto.type.UnknownType.UNKNOWN; @@ -82,13 +89,22 @@ public final class TypeRegistry private final ConcurrentMap types = new ConcurrentHashMap<>(); private final ConcurrentMap parametricTypes = new ConcurrentHashMap<>(); + private FunctionRegistry functionRegistry; + + @VisibleForTesting public TypeRegistry() { - this(ImmutableSet.of()); + this(ImmutableSet.of(), new FeaturesConfig()); } - @Inject + @VisibleForTesting public TypeRegistry(Set types) + { + this(ImmutableSet.of(), new FeaturesConfig()); + } + + @Inject + public TypeRegistry(Set types, FeaturesConfig featuresConfig) { requireNonNull(types, "types is null"); @@ -120,12 +136,13 @@ public TypeRegistry(Set types) addType(COLOR); addType(JSON); addType(CODE_POINTS); + addType(LIST_LITERAL); addParametricType(VarcharParametricType.VARCHAR); addParametricType(CharParametricType.CHAR); addParametricType(DecimalParametricType.DECIMAL); addParametricType(ROW); addParametricType(ARRAY); - addParametricType(MAP); + addParametricType(new MapParametricType(featuresConfig.isNewMapBlock())); addParametricType(FUNCTION); for (Type type : types) { @@ -133,6 +150,12 @@ public TypeRegistry(Set types) } } + public void setFunctionRegistry(FunctionRegistry functionRegistry) + { + checkState(this.functionRegistry == null, "TypeRegistry can only be associated with a single FunctionRegistry"); + this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null"); + } + @Override public Type getType(TypeSignature signature) { @@ -167,7 +190,7 @@ private Type instantiateParametricType(TypeSignature signature) } try { - Type instantiatedType = parametricType.createType(parameters); + Type instantiatedType = parametricType.createType(this, parameters); // TODO: reimplement this check? Currently "varchar(Integer.MAX_VALUE)" fails with "varchar" //checkState(instantiatedType.equalsSignature(signature), "Instantiated parametric type name (%s) does not match expected name (%s)", instantiatedType, signature); @@ -523,6 +546,14 @@ public Optional coerceTypeBase(Type sourceType, String resultTypeBase) return Optional.empty(); } } + case StandardTypes.ARRAY: { + switch (resultTypeBase) { + case ListLiteralType.NAME: + return Optional.of(LIST_LITERAL); + default: + return Optional.empty(); + } + } default: return Optional.empty(); } @@ -538,4 +569,11 @@ public static boolean isCovariantTypeBase(String typeBase) { return typeBase.equals(StandardTypes.ARRAY) || typeBase.equals(StandardTypes.MAP); } + + @Override + public MethodHandle resolveOperator(OperatorType operatorType, List argumentTypes) + { + requireNonNull(functionRegistry, "functionRegistry is null"); + return functionRegistry.getScalarFunctionImplementation(functionRegistry.resolveOperator(operatorType, argumentTypes)).getMethodHandle(); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/type/VarcharParametricType.java b/presto-main/src/main/java/com/facebook/presto/type/VarcharParametricType.java index 1976b4a8153dc..64db709d8e482 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/VarcharParametricType.java +++ b/presto-main/src/main/java/com/facebook/presto/type/VarcharParametricType.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import com.facebook.presto.spi.type.VarcharType; @@ -35,7 +36,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { if (parameters.isEmpty()) { return createUnboundedVarcharType(); diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 0081dd1ba1303..6f196db804860 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -18,6 +18,7 @@ import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; @@ -28,6 +29,7 @@ import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; @@ -52,7 +54,6 @@ import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.base.Joiner; import com.google.common.base.Preconditions; @@ -318,12 +319,12 @@ public Void visitExchange(ExchangeNode node, Void context) public Void visitAggregation(AggregationNode node, Void context) { StringBuilder builder = new StringBuilder(); - for (Map.Entry entry : node.getAggregations().entrySet()) { - if (node.getMasks().containsKey(entry.getKey())) { - builder.append(format("%s := %s (mask = %s)\\n", entry.getKey(), entry.getValue(), node.getMasks().get(entry.getKey()))); + for (Map.Entry entry : node.getAggregations().entrySet()) { + if (entry.getValue().getMask().isPresent()) { + builder.append(format("%s := %s (mask = %s)\\n", entry.getKey(), entry.getValue().getCall(), entry.getValue().getMask().get())); } else { - builder.append(format("%s := %s\\n", entry.getKey(), entry.getValue())); + builder.append(format("%s := %s\\n", entry.getKey(), entry.getValue().getCall())); } } printNode(node, format("Aggregate[%s]", node.getStep()), builder.toString(), NODE_COLORS.get(NodeType.AGGREGATE)); @@ -437,9 +438,7 @@ public Void visitJoin(JoinNode node, Void context) { List joinExpressions = new ArrayList<>(); for (JoinNode.EquiJoinClause clause : node.getCriteria()) { - joinExpressions.add(new ComparisonExpression(ComparisonExpressionType.EQUAL, - clause.getLeft().toSymbolReference(), - clause.getRight().toSymbolReference())); + joinExpressions.add(clause.toExpression()); } String criteria = Joiner.on(" AND ").join(joinExpressions); @@ -483,6 +482,18 @@ public Void visitAssignUniqueId(AssignUniqueId node, Void context) return null; } + @Override + public Void visitLateralJoin(LateralJoinNode node, Void context) + { + String parameters = Joiner.on(",").join(node.getCorrelation()); + printNode(node, "LateralJoin", parameters, NODE_COLORS.get(NodeType.JOIN)); + + node.getInput().accept(this, context); + node.getSubquery().accept(this, context); + + return null; + } + @Override public Void visitIndexSource(IndexSourceNode node, Void context) { diff --git a/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java b/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java index d3aa40100f704..d2d7e0ec9cf3f 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java @@ -16,13 +16,13 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.facebook.presto.type.UnknownType; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; diff --git a/presto-main/src/main/java/com/facebook/presto/util/MoreSets.java b/presto-main/src/main/java/com/facebook/presto/util/MoreLists.java similarity index 58% rename from presto-main/src/main/java/com/facebook/presto/util/MoreSets.java rename to presto-main/src/main/java/com/facebook/presto/util/MoreLists.java index 023b044c01249..e8f75f5aaa457 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/MoreSets.java +++ b/presto-main/src/main/java/com/facebook/presto/util/MoreLists.java @@ -13,26 +13,21 @@ */ package com.facebook.presto.util; -import com.google.common.collect.Iterables; -import com.google.common.collect.Sets; +import com.google.common.collect.ImmutableList; -import java.util.Set; +import java.util.List; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; -public class MoreSets +public class MoreLists { - private MoreSets() {} - - public static Set newIdentityHashSet() + public static List> listOfListsCopy(List> lists) { - return Sets.newIdentityHashSet(); + return requireNonNull(lists, "lists is null").stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); } - public static Set newIdentityHashSet(Iterable elements) - { - Set set = newIdentityHashSet(); - Iterables.addAll(set, requireNonNull(elements, "elements cannot be null")); - return set; - } + private MoreLists() {} } diff --git a/presto-main/src/main/java/com/facebook/presto/util/MoreMaps.java b/presto-main/src/main/java/com/facebook/presto/util/MoreMaps.java index 153f79528e725..64c59b5515459 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/MoreMaps.java +++ b/presto-main/src/main/java/com/facebook/presto/util/MoreMaps.java @@ -26,7 +26,12 @@ private MoreMaps() {} public static Map mergeMaps(Map map1, Map map2, BinaryOperator merger) { - return Stream.of(map1, map2) + return mergeMaps(Stream.of(map1, map2), merger); + } + + public static Map mergeMaps(Stream> mapStream, BinaryOperator merger) + { + return mapStream .map(Map::entrySet) .flatMap(Collection::stream) .collect(toMap(Map.Entry::getKey, Map.Entry::getValue, merger)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/Predicates.java b/presto-main/src/main/java/com/facebook/presto/util/MorePredicates.java similarity index 71% rename from presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/Predicates.java rename to presto-main/src/main/java/com/facebook/presto/util/MorePredicates.java index 53d21d881f4b4..5e8cc3345248d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/Predicates.java +++ b/presto-main/src/main/java/com/facebook/presto/util/MorePredicates.java @@ -11,30 +11,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.optimizations; +package com.facebook.presto.util; import java.util.function.Predicate; -public class Predicates +import static com.google.common.base.Predicates.alwaysFalse; + +public class MorePredicates { - private Predicates() {} + private MorePredicates() {} public static Predicate isInstanceOfAny(Class... classes) { - Predicate predicate = alwaysFalse(); + Predicate predicate = alwaysFalse(); for (Class clazz : classes) { predicate = predicate.or(clazz::isInstance); } return predicate; } - - public static Predicate alwaysTrue() - { - return x -> true; - } - - public static Predicate alwaysFalse() - { - return x -> false; - } } diff --git a/presto-main/src/main/java/com/facebook/presto/util/Reflection.java b/presto-main/src/main/java/com/facebook/presto/util/Reflection.java index b817798c37908..841bf5cb58a2e 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/Reflection.java +++ b/presto-main/src/main/java/com/facebook/presto/util/Reflection.java @@ -18,6 +18,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; +import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -47,6 +48,14 @@ public static Method method(Class clazz, String name, Class... parameterTy } } + /** + * Returns a MethodHandle corresponding to the specified method. + *

+ * Warning: The way Oracle JVM implements producing MethodHandle for a method involves creating + * JNI global weak references. G1 processes such references serially. As a result, calling this + * method in a tight loop can create significant GC pressure and significantly increase + * application pause time. + */ public static MethodHandle methodHandle(Class clazz, String name, Class... parameterTypes) { try { @@ -57,16 +66,58 @@ public static MethodHandle methodHandle(Class clazz, String name, Class... } } - public static MethodHandle methodHandle(Method method) + /** + * Returns a MethodHandle corresponding to the specified method. + *

+ * Warning: The way Oracle JVM implements producing MethodHandle for a method involves creating + * JNI global weak references. G1 processes such references serially. As a result, calling this + * method in a tight loop can create significant GC pressure and significantly increase + * application pause time. + */ + public static MethodHandle methodHandle(StandardErrorCode errorCode, Method method) { try { return MethodHandles.lookup().unreflect(method); } catch (IllegalAccessException e) { - throw new PrestoException(GENERIC_INTERNAL_ERROR, e); + throw new PrestoException(errorCode, e); } } + /** + * Returns a MethodHandle corresponding to the specified method. + *

+ * Warning: The way Oracle JVM implements producing MethodHandle for a method involves creating + * JNI global weak references. G1 processes such references serially. As a result, calling this + * method in a tight loop can create significant GC pressure and significantly increase + * application pause time. + */ + public static MethodHandle methodHandle(Method method) + { + return methodHandle(GENERIC_INTERNAL_ERROR, method); + } + + /** + * Returns a MethodHandle corresponding to the specified constructor. + *

+ * Warning: The way Oracle JVM implements producing MethodHandle for a constructor involves + * creating JNI global weak references. G1 processes such references serially. As a result, + * calling this method in a tight loop can create significant GC pressure and significantly + * increase application pause time. + */ + public static MethodHandle constructorMethodHandle(Class clazz, Class... parameterTypes) + { + return constructorMethodHandle(GENERIC_INTERNAL_ERROR, clazz, parameterTypes); + } + + /** + * Returns a MethodHandle corresponding to the specified constructor. + *

+ * Warning: The way Oracle JVM implements producing MethodHandle for a constructor involves + * creating JNI global weak references. G1 processes such references serially. As a result, + * calling this method in a tight loop can create significant GC pressure and significantly + * increase application pause time. + */ public static MethodHandle constructorMethodHandle(StandardErrorCode errorCode, Class clazz, Class... parameterTypes) { try { @@ -76,4 +127,22 @@ public static MethodHandle constructorMethodHandle(StandardErrorCode errorCode, throw new PrestoException(errorCode, e); } } + + /** + * Returns a MethodHandle corresponding to the specified constructor. + *

+ * Warning: The way Oracle JVM implements producing MethodHandle for a constructor involves + * creating JNI global weak references. G1 processes such references serially. As a result, + * calling this method in a tight loop can create significant GC pressure and significantly + * increase application pause time. + */ + public static MethodHandle constructorMethodHandle(StandardErrorCode errorCode, Constructor constructor) + { + try { + return MethodHandles.lookup().unreflectConstructor(constructor); + } + catch (IllegalAccessException e) { + throw new PrestoException(errorCode, e); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/util/maps/IdentityLinkedHashMap.java b/presto-main/src/main/java/com/facebook/presto/util/maps/IdentityLinkedHashMap.java deleted file mode 100644 index d009a7607c3ad..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/util/maps/IdentityLinkedHashMap.java +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.util.maps; - -import com.google.common.base.Equivalence; -import com.google.common.collect.Iterators; - -import java.util.AbstractMap; -import java.util.Collection; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Set; - -import static java.util.stream.Collectors.joining; - -public class IdentityLinkedHashMap - implements Map -{ - private final Map, V> delegate = new LinkedHashMap<>(); - private final Equivalence equivalence = Equivalence.identity(); - - public IdentityLinkedHashMap() - { - } - - public IdentityLinkedHashMap(IdentityLinkedHashMap map) - { - putAll(map); - } - - @Override - public int size() - { - return delegate.size(); - } - - @Override - public boolean isEmpty() - { - return delegate.isEmpty(); - } - - @Override - public boolean containsKey(Object key) - { - return delegate.containsKey(equivalence.wrap(key)); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean containsValue(Object value) - { - // should use identity-based comparison - throw new UnsupportedOperationException(); - } - - @Override - public V get(Object key) - { - return delegate.get(equivalence.wrap(key)); - } - - @Override - public V put(K key, V value) - { - return delegate.put(equivalence.wrap(key), value); - } - - @Override - public V remove(Object key) - { - return delegate.remove(equivalence.wrap(key)); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean remove(Object key, Object value) - { - // should use identity-based comparison for value too - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean replace(K key, V oldValue, V newValue) - { - // should use identity-based comparison for value too - throw new UnsupportedOperationException(); - } - - @Override - public void putAll(Map map) - { - map.entrySet().forEach(e -> delegate.put(equivalence.wrap(e.getKey()), e.getValue())); - } - - @Override - public void clear() - { - delegate.clear(); - } - - @Override - public IterateOnlySetView keySet() - { - return new IterateOnlySetView() - { - @Override - public Iterator iterator() - { - return delegate.keySet().stream().map(Equivalence.Wrapper::get).iterator(); - } - }; - } - - @Override - public IterateOnlyCollectionView values() - { - return new IterateOnlyCollectionView() - { - @Override - public Iterator iterator() - { - return Iterators.unmodifiableIterator(delegate.values().iterator()); - } - }; - } - - @Override - public IterateOnlySetView> entrySet() - { - return new IterateOnlySetView>() - { - @Override - public Iterator> iterator() - { - return delegate.entrySet().stream().map(e -> { - K key = e.getKey().get(); - return (Entry) new AbstractMap.SimpleEntry<>(key, e.getValue()); - }).iterator(); - } - }; - } - - public abstract class IterateOnlySetView - extends IterateOnlyCollectionView - implements Set - { - private IterateOnlySetView() {} - } - - public abstract class IterateOnlyCollectionView - implements Collection - { - private IterateOnlyCollectionView() {} - - @Override - public int size() - { - return IdentityLinkedHashMap.this.size(); - } - - @Override - public boolean isEmpty() - { - return size() == 0; - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean contains(Object item) - { - // should use identity-based comparison whenever map's keys or values are compared - throw new UnsupportedOperationException(); - } - - @Override - public Object[] toArray() - { - return Iterators.toArray(iterator(), Object.class); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public T[] toArray(T[] array) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean add(E item) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean remove(Object item) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean containsAll(Collection other) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean addAll(Collection other) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean retainAll(Collection other) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean removeAll(Collection other) - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public void clear() - { - throw new UnsupportedOperationException(); - } - - @Override - public String toString() - { - return stream() - .map(String::valueOf) - .collect(joining(", ", "[", "]")); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public int hashCode() - { - throw new UnsupportedOperationException(); - } - - /** - * @deprecated Unsupported operation. - */ - @Deprecated - @Override - public boolean equals(Object obj) - { - throw new UnsupportedOperationException(); - } - } -} diff --git a/presto-main/src/main/resources/webapp/assets/cluster-hud.js b/presto-main/src/main/resources/webapp/assets/cluster-hud.js index a69d8bda974f8..4a9ee22024f45 100644 --- a/presto-main/src/main/resources/webapp/assets/cluster-hud.js +++ b/presto-main/src/main/resources/webapp/assets/cluster-hud.js @@ -12,7 +12,7 @@ * limitations under the License. */ -var SPARKLINE_PROPERTIES = { +const SPARKLINE_PROPERTIES = { width:'100%', height: '75px', fillColor:'#3F4552', @@ -20,9 +20,9 @@ var SPARKLINE_PROPERTIES = { spotColor: '#1EDCFF', tooltipClassname: 'sparkline-tooltip', disableHiddenCheck: true, -} +}; -var ClusterHUD = React.createClass({ +let ClusterHUD = React.createClass({ getInitialState: function() { return { runningQueries: [], @@ -33,9 +33,14 @@ var ClusterHUD = React.createClass({ reservedMemory: [], rowInputRate: [], byteInputRate: [], - cpuTimeRate: [], + perWorkerCpuTimeRate: [], lastRender: null, + lastRefresh: null, + + lastInputRows: null, + lastInputBytes: null, + lastCpuTime: null, initialized: false, }; @@ -43,13 +48,28 @@ var ClusterHUD = React.createClass({ resetTimer: function() { clearTimeout(this.timeoutId); // stop refreshing when query finishes or fails - if (this.state.query == null || !this.state.ended) { + if (this.state.query === null || !this.state.ended) { this.timeoutId = setTimeout(this.refreshLoop, 1000); } }, refreshLoop: function() { clearTimeout(this.timeoutId); // to stop multiple series of refreshLoop from going on simultaneously $.get('/v1/cluster', function (clusterState) { + + let newRowInputRate = []; + let newByteInputRate = []; + let newPerWorkerCpuTimeRate = []; + if (this.state.lastRefresh !== null) { + const rowsInputSinceRefresh = clusterState.totalInputRows - this.state.lastInputRows; + const bytesInputSinceRefresh = clusterState.totalInputBytes - this.state.lastInputBytes; + const cpuTimeSinceRefresh = clusterState.totalCpuTimeSecs - this.state.lastCpuTime; + const secsSinceRefresh = (Date.now() - this.state.lastRefresh) / 1000.0; + + newRowInputRate = addExponentiallyWeightedToHistory(rowsInputSinceRefresh / secsSinceRefresh, this.state.rowInputRate); + newByteInputRate = addExponentiallyWeightedToHistory(bytesInputSinceRefresh / secsSinceRefresh, this.state.byteInputRate); + newPerWorkerCpuTimeRate = addExponentiallyWeightedToHistory((cpuTimeSinceRefresh / clusterState.activeWorkers) / secsSinceRefresh, this.state.perWorkerCpuTimeRate); + } + this.setState({ // instantaneous stats runningQueries: addToHistory(clusterState.runningQueries, this.state.runningQueries), @@ -60,11 +80,19 @@ var ClusterHUD = React.createClass({ // moving averages runningDrivers: addExponentiallyWeightedToHistory(clusterState.runningDrivers, this.state.runningDrivers), reservedMemory: addExponentiallyWeightedToHistory(clusterState.reservedMemory, this.state.reservedMemory), - rowInputRate: addExponentiallyWeightedToHistory(clusterState.rowInputRate, this.state.rowInputRate), - byteInputRate: addExponentiallyWeightedToHistory(clusterState.byteInputRate, this.state.byteInputRate), - cpuTimeRate: addExponentiallyWeightedToHistory(clusterState.cpuTimeRate, this.state.cpuTimeRate), + + // moving averages for diffs + rowInputRate: newRowInputRate, + byteInputRate: newByteInputRate, + perWorkerCpuTimeRate: newPerWorkerCpuTimeRate, + + lastInputRows: clusterState.totalInputRows, + lastInputBytes: clusterState.totalInputBytes, + lastCpuTime: clusterState.totalCpuTimeSecs, initialized: true, + + lastRefresh: Date.now() }); this.resetTimer(); }.bind(this)) @@ -77,8 +105,8 @@ var ClusterHUD = React.createClass({ }, componentDidUpdate: function() { // prevent multiple calls to componentDidUpdate (resulting from calls to setState or otherwise) within the refresh interval from re-rendering sparklines/charts - if (this.state.lastRender == null || (Date.now() - this.state.lastRender) >= 1000) { - var renderTimestamp = Date.now(); + if (this.state.lastRender === null || (Date.now() - this.state.lastRender) >= 1000) { + const renderTimestamp = Date.now(); $('#running-queries-sparkline').sparkline(this.state.runningQueries, $.extend({}, SPARKLINE_PROPERTIES, {chartRangeMin: 0})); $('#blocked-queries-sparkline').sparkline(this.state.blockedQueries, $.extend({}, SPARKLINE_PROPERTIES, {chartRangeMin: 0})); $('#queued-queries-sparkline').sparkline(this.state.queuedQueries, $.extend({}, SPARKLINE_PROPERTIES, {chartRangeMin: 0})); @@ -89,7 +117,7 @@ var ClusterHUD = React.createClass({ $('#row-input-rate-sparkline').sparkline(this.state.rowInputRate, $.extend({}, SPARKLINE_PROPERTIES, {numberFormatter: formatCount})); $('#byte-input-rate-sparkline').sparkline(this.state.byteInputRate, $.extend({}, SPARKLINE_PROPERTIES, {numberFormatter: formatDataSizeBytes})); - $('#cpu-time-rate-sparkline').sparkline(this.state.cpuTimeRate, $.extend({}, SPARKLINE_PROPERTIES, {numberFormatter: precisionRound})); + $('#cpu-time-rate-sparkline').sparkline(this.state.perWorkerCpuTimeRate, $.extend({}, SPARKLINE_PROPERTIES, {numberFormatter: precisionRound})); this.setState({ lastRender: renderTimestamp @@ -216,8 +244,8 @@ var ClusterHUD = React.createClass({
- - Parallelism + + Worker Parallelism
@@ -242,7 +270,7 @@ var ClusterHUD = React.createClass({
- { formatCount(this.state.cpuTimeRate[this.state.cpuTimeRate.length - 1]) } + { formatCount(this.state.perWorkerCpuTimeRate[this.state.perWorkerCpuTimeRate.length - 1]) }
Loading ...
diff --git a/presto-main/src/main/resources/webapp/assets/plan.js b/presto-main/src/main/resources/webapp/assets/plan.js index 9a889e33719e0..39601fa1ef8bf 100644 --- a/presto-main/src/main/resources/webapp/assets/plan.js +++ b/presto-main/src/main/resources/webapp/assets/plan.js @@ -107,7 +107,7 @@ let LivePlan = React.createClass({ resetTimer: function() { clearTimeout(this.timeoutId); // stop refreshing when query finishes or fails - if (this.state.query == null || !this.state.ended) { + if (this.state.query === null || !this.state.ended) { this.timeoutId = setTimeout(this.refreshLoop, 1000); } }, @@ -148,7 +148,7 @@ let LivePlan = React.createClass({ }, refreshLoop: function() { clearTimeout(this.timeoutId); // to stop multiple series of refreshLoop from going on simultaneously - const queryId = window.location.search.substring(1); + const queryId = getFirstParameter(window.location.search); $.get('/v1/query/' + queryId, function (query) { this.setState({ query: query, @@ -195,7 +195,7 @@ let LivePlan = React.createClass({ graph.setEdge("node-" + source, nodeId, {arrowheadStyle: "fill: #fff; stroke-width: 0;"}); }); - if (node.type == 'remoteSource') { + if (node.type === 'remoteSource') { graph.setNode(nodeId, {label: '', shape: "circle"}); node.remoteSources.forEach(sourceId => { @@ -223,7 +223,7 @@ let LivePlan = React.createClass({ svg.attr("width", graph.graph().width); }, findStage: function (stageId, currentStage) { - if (stageId == -1) { + if (stageId === -1) { return null; } @@ -233,7 +233,7 @@ let LivePlan = React.createClass({ for (let i = 0; i < currentStage.subStages.length; i++) { const stage = this.findStage(stageId, currentStage.subStages[i]); - if (stage != null) { + if (stage !== null) { return stage; } } @@ -243,7 +243,7 @@ let LivePlan = React.createClass({ render: function() { const query = this.state.query; - if (query == null || this.state.initialized == false) { + if (query === null || this.state.initialized === false) { let label = (
Loading...
); if (this.state.initialized) { label = "Query not found"; @@ -314,7 +314,6 @@ let LivePlan = React.createClass({ { this.renderProgressBar() }
-
{ livePlanGraph } diff --git a/presto-main/src/main/resources/webapp/assets/presto.css b/presto-main/src/main/resources/webapp/assets/presto.css index 3a38d76744c8e..928bc6809cfb6 100644 --- a/presto-main/src/main/resources/webapp/assets/presto.css +++ b/presto-main/src/main/resources/webapp/assets/presto.css @@ -124,6 +124,10 @@ pre { font-style: italic; } +.uppercase { + text-transform: uppercase; +} + .h2-hr { margin: 0; border-top: 2px solid #fff; @@ -166,7 +170,6 @@ pre { padding: 15px; position: relative; color: #666; - text-transform: uppercase; font-size: 11px; } @@ -222,6 +225,27 @@ pre { color: #fff; } +.status-light { + border-radius: 50%; + text-shadow: 0 0 3px #c3b300; + color: #c3b300; + font-size: 16px; +} + +.status-light-red { + color: #ff0100; + text-shadow: 0 0 3px #ff0100; +} + +.status-light-green { + color: #1cff00; + text-shadow: 0 0 3px #1cff00; +} + +.status-light:before { + content: '\25CF'; +} + /** ===================== **/ /** Cluster Overview Page **/ /** ===================== **/ diff --git a/presto-main/src/main/resources/webapp/assets/query-list.js b/presto-main/src/main/resources/webapp/assets/query-list.js index ed458496ddf0a..1043bc08b4e0c 100644 --- a/presto-main/src/main/resources/webapp/assets/query-list.js +++ b/presto-main/src/main/resources/webapp/assets/query-list.js @@ -12,31 +12,31 @@ * limitations under the License. */ - var QueryListItem = React.createClass({ + let QueryListItem = React.createClass({ formatQueryText: function(queryText) { - var lines = queryText.split("\n"); - var minLeadingWhitespace = -1; - for (var i = 0; i < lines.length; i++) { - if (minLeadingWhitespace == 0) { + const lines = queryText.split("\n"); + let minLeadingWhitespace = -1; + for (let i = 0; i < lines.length; i++) { + if (minLeadingWhitespace === 0) { break; } - if (lines[i].trim().length == 0) { + if (lines[i].trim().length === 0) { continue; } - var leadingWhitespace = lines[i].search(/\S/); + const leadingWhitespace = lines[i].search(/\S/); - if (leadingWhitespace > -1 && ((leadingWhitespace < minLeadingWhitespace) || minLeadingWhitespace == -1)) { + if (leadingWhitespace > -1 && ((leadingWhitespace < minLeadingWhitespace) || minLeadingWhitespace === -1)) { minLeadingWhitespace = leadingWhitespace; } } - var formattedQueryText = ""; + let formattedQueryText = ""; - for (i = 0; i < lines.length; i++) { - var trimmedLine = lines[i].substring(minLeadingWhitespace).replace(/\s+$/g, ''); + for (let i = 0; i < lines.length; i++) { + const trimmedLine = lines[i].substring(minLeadingWhitespace).replace(/\s+$/g, ''); if (trimmedLine.length > 0) { formattedQueryText += trimmedLine; @@ -51,61 +51,61 @@ }, render: function() { - var query = this.props.query; - var progressBarStyle = { width: getProgressBarPercentage(query) + "%", backgroundColor: getQueryStateColor(query) }; + const query = this.props.query; + const progressBarStyle = {width: getProgressBarPercentage(query) + "%", backgroundColor: getQueryStateColor(query)}; - var splitDetails = ( + const splitDetails = (
-    +    { query.queryStats.completedDrivers } -    - { (query.state == "FINISHED" || query.state == "FAILED") ? 0 : query.queryStats.runningDrivers } +    + { (query.state === "FINISHED" || query.state === "FAILED") ? 0 : query.queryStats.runningDrivers } -    - { (query.state == "FINISHED" || query.state == "FAILED") ? 0 : query.queryStats.queuedDrivers } +    + { (query.state === "FINISHED" || query.state === "FAILED") ? 0 : query.queryStats.queuedDrivers }
); - var timingDetails = ( + const timingDetails = (
-    +    { query.queryStats.executionTime } -    +    { query.queryStats.elapsedTime } - -    + +    { query.queryStats.totalCpuTime }
); - var memoryDetails = ( + const memoryDetails = (
-    +    { query.queryStats.totalMemoryReservation } -    +    { query.queryStats.peakMemoryReservation } - -    + +    { formatDataSizeBytes(query.queryStats.cumulativeMemory) }
); - var user = ({ query.session.user }); + let user = ({ query.session.user }); if (query.session.principal) { user = ( - { query.session.user } + { query.session.user } ); } @@ -124,7 +124,7 @@
-    +    { truncateString(user, 35) }
@@ -132,7 +132,7 @@
-    +    { truncateString(query.session.source, 35) }
@@ -169,12 +169,12 @@ } }); -var DisplayedQueriesList = React.createClass({ +let DisplayedQueriesList = React.createClass({ render: function() { - var queryNodes = this.props.queries.map(function (query) { + const queryNodes = this.props.queries.map(function (query) { return ( - + ); }.bind(this)); return ( @@ -185,32 +185,32 @@ var DisplayedQueriesList = React.createClass({ } }); -var FILTER_TYPE = { - RUNNING_BLOCKED: function(query) { - return query.state == "PLANNING" || query.state == "STARTING" || query.state == "RUNNING" || query.state == "FINISHING"; +const FILTER_TYPE = { + RUNNING_BLOCKED: function (query) { + return query.state === "PLANNING" || query.state === "STARTING" || query.state === "RUNNING" || query.state === "FINISHING"; }, - QUEUED: function(query) { return query.state == "QUEUED"}, - FINISHED: function(query) { return query.state == "FINISHED"}, - FAILED: function(query) { return query.state == "FAILED" && query.errorType != "USER_ERROR"}, - USER_ERROR: function(query) { return query.state == "FAILED" && query.errorType == "USER_ERROR"}, + QUEUED: function (query) { return query.state === "QUEUED"}, + FINISHED: function (query) { return query.state === "FINISHED"}, + FAILED: function (query) { return query.state === "FAILED" && query.errorType !== "USER_ERROR"}, + USER_ERROR: function (query) { return query.state === "FAILED" && query.errorType === "USER_ERROR"}, }; -var SORT_TYPE = { - CREATED: function(query) {return Date.parse(query.queryStats.createTime)}, - ELAPSED: function(query) {return parseDuration(query.queryStats.elapsedTime)}, - EXECUTION: function(query) {return parseDuration(query.queryStats.executionTime)}, - CPU: function(query) {return parseDuration(query.queryStats.totalCpuTime)}, - CUMULATIVE_MEMORY: function(query) {return query.queryStats.cumulativeMemory}, - CURRENT_MEMORY: function(query) {return parseDataSize(query.queryStats.totalMemoryReservation)}, +const SORT_TYPE = { + CREATED: function (query) {return Date.parse(query.queryStats.createTime)}, + ELAPSED: function (query) {return parseDuration(query.queryStats.elapsedTime)}, + EXECUTION: function (query) {return parseDuration(query.queryStats.executionTime)}, + CPU: function (query) {return parseDuration(query.queryStats.totalCpuTime)}, + CUMULATIVE_MEMORY: function (query) {return query.queryStats.cumulativeMemory}, + CURRENT_MEMORY: function (query) {return parseDataSize(query.queryStats.totalMemoryReservation)}, }; -var SORT_ORDER = { - ASCENDING: function(value) {return value}, - DESCENDING: function(value) {return -value} +const SORT_ORDER = { + ASCENDING: function (value) {return value}, + DESCENDING: function (value) {return -value} }; -var QueryList = React.createClass({ - getInitialState: function() { +let QueryList = React.createClass({ + getInitialState: function () { return { allQueries: [], displayedQueries: [], @@ -222,20 +222,21 @@ var QueryList = React.createClass({ maxQueries: 100, lastRefresh: Date.now(), lastReorder: Date.now(), - initialized: false}; + initialized: false + }; }, - sortAndLimitQueries: function(queries, sortType, sortOrder, maxQueries) { - queries.sort(function(queryA, queryB) { + sortAndLimitQueries: function (queries, sortType, sortOrder, maxQueries) { + queries.sort(function (queryA, queryB) { return sortOrder(sortType(queryA) - sortType(queryB)); }, this); - if (maxQueries != 0 && queries.length > maxQueries) { + if (maxQueries !== 0 && queries.length > maxQueries) { queries.splice(maxQueries, (queries.length - maxQueries)); } }, - filterQueries: function(queries, filters, searchString) { - var stateFilteredQueries = queries.filter(function(query) { - for (var i = 0; i < filters.length; i++) { + filterQueries: function (queries, filters, searchString) { + const stateFilteredQueries = queries.filter(function (query) { + for (let i = 0; i < filters.length; i++) { if (filters[i](query)) { return true; } @@ -243,47 +244,47 @@ var QueryList = React.createClass({ return false; }); - if (searchString == '') { + if (searchString === '') { return stateFilteredQueries; } else { - return stateFilteredQueries.filter(function(query) { - var term = searchString.toLowerCase(); - if (query.queryId.toLowerCase().indexOf(term) != -1 || - getHumanReadableState(query).toLowerCase().indexOf(term) != -1 || - query.query.toLowerCase().indexOf(term) != -1) { + return stateFilteredQueries.filter(function (query) { + const term = searchString.toLowerCase(); + if (query.queryId.toLowerCase().indexOf(term) !== -1 || + getHumanReadableState(query).toLowerCase().indexOf(term) !== -1 || + query.query.toLowerCase().indexOf(term) !== -1) { return true; } - if (query.session.user && query.session.user.toLowerCase().indexOf(term) != -1) { + if (query.session.user && query.session.user.toLowerCase().indexOf(term) !== -1) { return true; } - if (query.session.source && query.session.source.toLowerCase().indexOf(term) != -1) { + if (query.session.source && query.session.source.toLowerCase().indexOf(term) !== -1) { return true; } }, this); } }, - resetTimer: function() { + resetTimer: function () { clearTimeout(this.timeoutId); // stop refreshing when query finishes or fails - if (this.state.query == null || !this.state.ended) { + if (this.state.query === null || !this.state.ended) { this.timeoutId = setTimeout(this.refreshLoop, 1000); } }, - refreshLoop: function() { + refreshLoop: function () { clearTimeout(this.timeoutId); // to stop multiple series of refreshLoop from going on simultaneously clearTimeout(this.searchTimeoutId); $.get('/v1/query', function (queryList) { - var queryMap = queryList.reduce(function(map, query) { + const queryMap = queryList.reduce(function (map, query) { map[query.queryId] = query; return map; }, {}); - var updatedQueries = []; + let updatedQueries = []; this.state.displayedQueries.forEach(function (oldQuery) { if (oldQuery.queryId in queryMap) { updatedQueries.push(queryMap[oldQuery.queryId]); @@ -291,18 +292,18 @@ var QueryList = React.createClass({ } }); - var newQueries = []; - for (var queryId in queryMap) { + let newQueries = []; + for (const queryId in queryMap) { if (queryMap[queryId]) { newQueries.push(queryMap[queryId]); } } newQueries = this.filterQueries(newQueries, this.state.filters, this.state.searchString); - var lastRefresh = Date.now(); - var lastReorder = this.state.lastReorder; + const lastRefresh = Date.now(); + let lastReorder = this.state.lastReorder; - if (this.state.reorderInterval != 0 && ((lastRefresh - lastReorder) >= this.state.reorderInterval)) { + if (this.state.reorderInterval !== 0 && ((lastRefresh - lastReorder) >= this.state.reorderInterval)) { updatedQueries = this.filterQueries(updatedQueries, this.state.filters, this.state.searchString); updatedQueries = updatedQueries.concat(newQueries); this.sortAndLimitQueries(updatedQueries, this.state.currentSortType, this.state.currentSortOrder, 0); @@ -313,7 +314,7 @@ var QueryList = React.createClass({ updatedQueries = updatedQueries.concat(newQueries); } - if (this.state.maxQueries != 0 && (updatedQueries.length > this.state.maxQueries)) { + if (this.state.maxQueries !== 0 && (updatedQueries.length > this.state.maxQueries)) { updatedQueries.splice(this.state.maxQueries, (updatedQueries.length - this.state.maxQueries)); } @@ -326,18 +327,18 @@ var QueryList = React.createClass({ }); this.resetTimer(); }.bind(this)) - .error(function() { - this.setState({ - initialized: true, - }) - this.resetTimer(); - }.bind(this)); + .error(function () { + this.setState({ + initialized: true, + }); + this.resetTimer(); + }.bind(this)); }, - componentDidMount: function() { + componentDidMount: function () { this.refreshLoop(); }, - handleSearchStringChange: function(event) { - var newSearchString = event.target.value; + handleSearchStringChange: function (event) { + const newSearchString = event.target.value; clearTimeout(this.searchTimeoutId); this.setState({ @@ -346,23 +347,24 @@ var QueryList = React.createClass({ this.searchTimeoutId = setTimeout(this.executeSearch, 200); }, - executeSearch: function() { + executeSearch: function () { clearTimeout(this.searchTimeoutId); - var newDisplayedQueries = this.filterQueries(this.state.allQueries, this.state.filters, this.state.searchString); + const newDisplayedQueries = this.filterQueries(this.state.allQueries, this.state.filters, this.state.searchString); this.sortAndLimitQueries(newDisplayedQueries, this.state.currentSortType, this.state.currentSortOrder, this.state.maxQueries); this.setState({ displayedQueries: newDisplayedQueries }); }, - renderMaxQueriesListItem: function(maxQueries, maxQueriesText) { + renderMaxQueriesListItem: function (maxQueries, maxQueriesText) { return ( -
  • { maxQueriesText }
  • +
  • { maxQueriesText } +
  • ); }, - handleMaxQueriesClick: function(newMaxQueries) { - var filteredQueries = this.filterQueries(this.state.allQueries, this.state.filters, this.state.searchString); + handleMaxQueriesClick: function (newMaxQueries) { + const filteredQueries = this.filterQueries(this.state.allQueries, this.state.filters, this.state.searchString); this.sortAndLimitQueries(filteredQueries, this.state.currentSortType, this.state.currentSortOrder, newMaxQueries); this.setState({ @@ -370,21 +372,22 @@ var QueryList = React.createClass({ displayedQueries: filteredQueries }); }, - renderReorderListItem: function(interval, intervalText) { + renderReorderListItem: function (interval, intervalText) { return ( -
  • { intervalText }
  • +
  • { intervalText }
  • ); }, - handleReorderClick: function(interval) { - if (this.state.reorderInterval != interval) { + handleReorderClick: function (interval) { + if (this.state.reorderInterval !== interval) { this.setState({ reorderInterval: interval, }); } }, - renderSortListItem: function(sortType, sortText) { - if (this.state.currentSortType == sortType) { - var directionArrow = this.state.currentSortOrder == SORT_ORDER.ASCENDING ? : ; + renderSortListItem: function (sortType, sortText) { + if (this.state.currentSortType === sortType) { + const directionArrow = this.state.currentSortOrder === SORT_ORDER.ASCENDING ? : + ; return (
  • @@ -401,15 +404,15 @@ var QueryList = React.createClass({
  • ); } }, - handleSortClick: function(sortType) { - var newSortType = sortType; - var newSortOrder = SORT_ORDER.DESCENDING; + handleSortClick: function (sortType) { + const newSortType = sortType; + let newSortOrder = SORT_ORDER.DESCENDING; - if (this.state.currentSortType == sortType && this.state.currentSortOrder == SORT_ORDER.DESCENDING) { + if (this.state.currentSortType === sortType && this.state.currentSortOrder === SORT_ORDER.DESCENDING) { newSortOrder = SORT_ORDER.ASCENDING; } - var newDisplayedQueries = this.filterQueries(this.state.allQueries, this.state.filters, this.state.searchString); + const newDisplayedQueries = this.filterQueries(this.state.allQueries, this.state.filters, this.state.searchString); this.sortAndLimitQueries(newDisplayedQueries, newSortType, newSortOrder, this.state.maxQueries); this.setState({ @@ -418,8 +421,8 @@ var QueryList = React.createClass({ currentSortOrder: newSortOrder }); }, - renderFilterButton: function(filterType, filterText) { - var classNames = "btn btn-sm btn-info style-check"; + renderFilterButton: function (filterType, filterText) { + let classNames = "btn btn-sm btn-info style-check"; if (this.state.filters.indexOf(filterType) > -1) { classNames += " active"; } @@ -428,8 +431,8 @@ var QueryList = React.createClass({ ); }, - handleFilterClick: function(filter) { - var newFilters = this.state.filters.slice(); + handleFilterClick: function (filter) { + const newFilters = this.state.filters.slice(); if (this.state.filters.indexOf(filter) > -1) { newFilters.splice(newFilters.indexOf(filter), 1); } @@ -437,7 +440,7 @@ var QueryList = React.createClass({ newFilters.push(filter); } - var filteredQueries = this.filterQueries(this.state.allQueries, newFilters, this.state.searchString); + const filteredQueries = this.filterQueries(this.state.allQueries, newFilters, this.state.searchString); this.sortAndLimitQueries(filteredQueries, this.state.currentSortType, this.state.currentSortOrder); this.setState({ @@ -445,12 +448,12 @@ var QueryList = React.createClass({ displayedQueries: filteredQueries }); }, - render: function() { - var queryList = ; - if (this.state.displayedQueries == null || this.state.displayedQueries.length == 0) { - var label = (
    Loading...
    ); + render: function () { + let queryList = ; + if (this.state.displayedQueries === null || this.state.displayedQueries.length === 0) { + let label = (
    Loading...
    ); if (this.state.initialized) { - if (this.state.allQueries == null || this.state.allQueries.length == 0) { + if (this.state.allQueries === null || this.state.allQueries.length === 0) { label = "No queries"; } else { @@ -458,76 +461,77 @@ var QueryList = React.createClass({ } } queryList = ( -
    -

    { label }

    -
    +
    +

    { label }

    +
    ); } return ( -
    -
    -
    -
    - - Filter: -
    - { this.renderFilterButton(FILTER_TYPE.RUNNING_BLOCKED, "Running/blocked") } - { this.renderFilterButton(FILTER_TYPE.QUEUED, "Queued") } - { this.renderFilterButton(FILTER_TYPE.FINISHED, "Finished") } - { this.renderFilterButton(FILTER_TYPE.FAILED, "Failed") } - { this.renderFilterButton(FILTER_TYPE.USER_ERROR, "User error") } -
    -   -
    - -
      - { this.renderSortListItem(SORT_TYPE.CREATED, "Creation Time") } - { this.renderSortListItem(SORT_TYPE.ELAPSED, "Elapsed Time") } - { this.renderSortListItem(SORT_TYPE.CPU, "CPU Time") } - { this.renderSortListItem(SORT_TYPE.EXECUTION, "Execution Time") } - { this.renderSortListItem(SORT_TYPE.CURRENT_MEMORY, "Current Memory") } - { this.renderSortListItem(SORT_TYPE.CUMULATIVE_MEMORY, "Cumulative Memory") } -
    -
    -   -
    - -
      - { this.renderReorderListItem(1000, "1s") } - { this.renderReorderListItem(5000, "5s") } - { this.renderReorderListItem(10000, "10s") } - { this.renderReorderListItem(30000, "30s") } -
    • - { this.renderReorderListItem(0, "Off") } -
    -
    -   -
    - -
      - { this.renderMaxQueriesListItem(20, "20 queries") } - { this.renderMaxQueriesListItem(50, "50 queries") } - { this.renderMaxQueriesListItem(100, "100 queries") } -
    • - { this.renderMaxQueriesListItem(0, "All queries") } -
    -
    +
    +
    +
    +
    + + Filter: +
    + { this.renderFilterButton(FILTER_TYPE.RUNNING_BLOCKED, "Running/blocked") } + { this.renderFilterButton(FILTER_TYPE.QUEUED, "Queued") } + { this.renderFilterButton(FILTER_TYPE.FINISHED, "Finished") } + { this.renderFilterButton(FILTER_TYPE.FAILED, "Failed") } + { this.renderFilterButton(FILTER_TYPE.USER_ERROR, "User error") } +
    +   +
    + +
      + { this.renderSortListItem(SORT_TYPE.CREATED, "Creation Time") } + { this.renderSortListItem(SORT_TYPE.ELAPSED, "Elapsed Time") } + { this.renderSortListItem(SORT_TYPE.CPU, "CPU Time") } + { this.renderSortListItem(SORT_TYPE.EXECUTION, "Execution Time") } + { this.renderSortListItem(SORT_TYPE.CURRENT_MEMORY, "Current Memory") } + { this.renderSortListItem(SORT_TYPE.CUMULATIVE_MEMORY, "Cumulative Memory") } +
    +
    +   +
    + +
      + { this.renderReorderListItem(1000, "1s") } + { this.renderReorderListItem(5000, "5s") } + { this.renderReorderListItem(10000, "10s") } + { this.renderReorderListItem(30000, "30s") } +
    • + { this.renderReorderListItem(0, "Off") } +
    +
    +   +
    + +
      + { this.renderMaxQueriesListItem(20, "20 queries") } + { this.renderMaxQueriesListItem(50, "50 queries") } + { this.renderMaxQueriesListItem(100, "100 queries") } +
    • + { this.renderMaxQueriesListItem(0, "All queries") } +
    - { queryList }
    + { queryList } +
    ); - } }); + ReactDOM.render( , document.getElementById('query-list') diff --git a/presto-main/src/main/resources/webapp/assets/query.js b/presto-main/src/main/resources/webapp/assets/query.js index 2ba31d217893e..ba67c09dd8699 100644 --- a/presto-main/src/main/resources/webapp/assets/query.js +++ b/presto-main/src/main/resources/webapp/assets/query.js @@ -29,7 +29,7 @@ let TaskList = React.createClass({ for (let i = 0; i < taskIdArrA.length; i++) { const anum = Number.parseInt(taskIdArrA[i]); const bnum = Number.parseInt(taskIdArrB[i]); - if (anum != bnum) return anum > bnum ? 1 : -1; + if (anum !== bnum) return anum > bnum ? 1 : -1; } return 0; @@ -41,7 +41,7 @@ let TaskList = React.createClass({ const taskUri = tasks[i].taskStatus.self; const hostname = getHostname(taskUri); const port = getPort(taskUri); - if ((hostname in hostToPortNumber) && (hostToPortNumber[hostname] != port)) { + if ((hostname in hostToPortNumber) && (hostToPortNumber[hostname] !== port)) { return true; } hostToPortNumber[hostname] = port; @@ -52,7 +52,7 @@ let TaskList = React.createClass({ render: function() { const tasks = this.props.tasks; - if (tasks === undefined || tasks.length == 0) { + if (tasks === undefined || tasks.length === 0) { return (
    @@ -66,7 +66,7 @@ let TaskList = React.createClass({ const renderedTasks = tasks.map(task => { let elapsedTime = parseDuration(task.stats.elapsedTime); - if (elapsedTime == 0) { + if (elapsedTime === 0) { elapsedTime = Date.now() - Date.parse(task.stats.createTime); } @@ -101,6 +101,9 @@ let TaskList = React.createClass({

    + @@ -128,6 +131,7 @@ let TaskList = React.createClass({ 'state', 'splitsPending', 'splitsRunning', + 'splitsBlocked', 'splitsDone', 'rows', 'rowsSec', @@ -142,9 +146,10 @@ let TaskList = React.createClass({ - - - + + + + @@ -211,7 +216,7 @@ let StageDetail = React.createClass({ const bucketSize = (dataMax - dataMin) / numBuckets; let histogramData = []; - if (bucketSize == 0) { + if (bucketSize === 0) { histogramData = [inputData.length]; } else { @@ -245,7 +250,7 @@ let StageDetail = React.createClass({ const cpuTimes = stage.tasks.map(task => parseDuration(task.stats.totalCpuTime)); // prevent multiple calls to componentDidUpdate (resulting from calls to setState or otherwise) within the refresh interval from re-rendering sparklines/charts - if (this.state.lastRender == null || (Date.now() - this.state.lastRender) >= 1000) { + if (this.state.lastRender === null || (Date.now() - this.state.lastRender) >= 1000) { const renderTimestamp = Date.now(); const stageId = getStageId(stage.stageId); @@ -402,7 +407,7 @@ let StageDetail = React.createClass({ Pending @@ -410,19 +415,15 @@ let StageDetail = React.createClass({ Running @@ -448,7 +449,7 @@ let StageDetail = React.createClass({ @@ -466,7 +467,7 @@ let StageDetail = React.createClass({ @@ -474,7 +475,7 @@ let StageDetail = React.createClass({ @@ -487,7 +488,7 @@ let StageDetail = React.createClass({ Task Scheduled Time @@ -503,7 +504,7 @@ let StageDetail = React.createClass({ Task CPU Time @@ -528,7 +529,7 @@ let StageList = React.createClass({ render: function() { const stages = this.getStages(this.props.outputStage); - if (stages === undefined || stages.length == 0) { + if (stages === undefined || stages.length === 0) { return (
    @@ -606,13 +607,14 @@ let QueryDetail = React.createClass({ resetTimer: function() { clearTimeout(this.timeoutId); // stop refreshing when query finishes or fails - if (this.state.query == null || !this.state.ended) { - this.timeoutId = setTimeout(this.refreshLoop, 1000); + if (this.state.query === null || !this.state.ended) { + // task.info-update-interval is set to 3 seconds by default + this.timeoutId = setTimeout(this.refreshLoop, 3000); } }, refreshLoop: function() { clearTimeout(this.timeoutId); // to stop multiple series of refreshLoop from going on simultaneously - const queryId = window.location.search.substring(1); + const queryId = getFirstParameter(window.location.search); $.get('/v1/query/' + queryId, function (query) { let lastSnapshotStages = this.state.lastSnapshotStage; if (this.state.stageRefresh) { @@ -648,19 +650,19 @@ let QueryDetail = React.createClass({ }); // i.e. don't show sparklines if we've already decided not to update or if we don't have one previous measurement - if (alreadyEnded || (lastRefresh == null && query.state == "RUNNING")) { + if (alreadyEnded || (lastRefresh === null && query.state === "RUNNING")) { this.resetTimer(); return; } - if (lastRefresh == null) { + if (lastRefresh === null) { lastRefresh = nowMillis - parseDuration(query.queryStats.elapsedTime); } - const elapsedSecsSinceLastRefresh = (nowMillis - lastRefresh) / 1000; - if (elapsedSecsSinceLastRefresh != 0) { - const currentScheduledTimeRate = (parseDuration(query.queryStats.totalScheduledTime) - lastScheduledTime) / elapsedSecsSinceLastRefresh; - const currentCpuTimeRate = (parseDuration(query.queryStats.totalCpuTime) - lastCpuTime) / elapsedSecsSinceLastRefresh; + const elapsedSecsSinceLastRefresh = (nowMillis - lastRefresh) / 1000.0; + if (elapsedSecsSinceLastRefresh >= 0) { + const currentScheduledTimeRate = (parseDuration(query.queryStats.totalScheduledTime) - lastScheduledTime) / (elapsedSecsSinceLastRefresh * 1000); + const currentCpuTimeRate = (parseDuration(query.queryStats.totalCpuTime) - lastCpuTime) / (elapsedSecsSinceLastRefresh * 1000); const currentRowInputRate = (query.queryStats.processedInputPositions - lastRowInput) / elapsedSecsSinceLastRefresh; const currentByteInputRate = (parseDataSize(query.queryStats.processedInputDataSize) - lastByteInput) / elapsedSecsSinceLastRefresh; this.setState({ @@ -724,7 +726,7 @@ let QueryDetail = React.createClass({ }, renderTaskFilterListItem: function(taskFilter, taskFilterText) { return ( -
  • { taskFilterText }
  • +
  • { taskFilterText }
  • ); }, handleTaskFilterClick: function(filter, event) { @@ -745,7 +747,7 @@ let QueryDetail = React.createClass({ }, componentDidUpdate: function() { // prevent multiple calls to componentDidUpdate (resulting from calls to setState or otherwise) within the refresh interval from re-rendering sparklines/charts - if (this.state.lastRender == null || (Date.now() - this.state.lastRender) >= 1000) { + if (this.state.lastRender === null || (Date.now() - this.state.lastRender) >= 1000) { const renderTimestamp = Date.now(); $('#scheduled-time-rate-sparkline').sparkline(this.state.scheduledTimeRate, $.extend({}, SMALL_SPARKLINE_PROPERTIES, {chartRangeMin: 0, numberFormatter: precisionRound})); $('#cpu-time-rate-sparkline').sparkline(this.state.cpuTimeRate, $.extend({}, SMALL_SPARKLINE_PROPERTIES, {chartRangeMin: 0, numberFormatter: precisionRound})); @@ -753,7 +755,7 @@ let QueryDetail = React.createClass({ $('#byte-input-rate-sparkline').sparkline(this.state.byteInputRate, $.extend({}, SMALL_SPARKLINE_PROPERTIES, {numberFormatter: formatDataSize})); $('#reserved-memory-sparkline').sparkline(this.state.reservedMemory, $.extend({}, SMALL_SPARKLINE_PROPERTIES, {numberFormatter: formatDataSize})); - if (this.state.lastRender == null) { + if (this.state.lastRender === null) { $('#query').each((i, block) => { hljs.highlightBlock(block); }); @@ -768,7 +770,7 @@ let QueryDetail = React.createClass({ new Clipboard('.copy-button'); }, renderTasks: function() { - if (this.state.lastSnapshotTasks == null) { + if (this.state.lastSnapshotTasks === null) { return; } @@ -786,7 +788,7 @@ let QueryDetail = React.createClass({
    - diff --git a/presto-main/src/main/resources/webapp/assets/stage.js b/presto-main/src/main/resources/webapp/assets/stage.js index 68e8b0ead73c0..142fe8b07df9e 100644 --- a/presto-main/src/main/resources/webapp/assets/stage.js +++ b/presto-main/src/main/resources/webapp/assets/stage.js @@ -360,7 +360,7 @@ let StageOperatorGraph = React.createClass({ }); let nodeOperators = operatorMap.get(planNode.id); - if (!nodeOperators || nodeOperators.length == 0) { + if (!nodeOperators || nodeOperators.length === 0) { return sourceResults; } @@ -427,7 +427,7 @@ let StageOperatorGraph = React.createClass({ this.computeD3StageOperatorGraph(graph, operator.child, operatorNodeId, pipelineNode); } - if (sink != null) { + if (sink !== null) { graph.setEdge(operatorNodeId, sink, {class: "edge-class", arrowheadStyle: "stroke-width: 0; fill: #fff;"}); } @@ -476,7 +476,7 @@ let StageOperatorGraph = React.createClass({ ); } - if (!stage.hasOwnProperty('stageStats') || !stage.stageStats.hasOwnProperty("operatorSummaries") || stage.stageStats.operatorSummaries.length == 0) { + if (!stage.hasOwnProperty('stageStats') || !stage.stageStats.hasOwnProperty("operatorSummaries") || stage.stageStats.operatorSummaries.length === 0) { return (
    @@ -506,7 +506,7 @@ let StagePerformance = React.createClass({ resetTimer: function() { clearTimeout(this.timeoutId); // stop refreshing when query finishes or fails - if (this.state.query == null || !this.state.ended) { + if (this.state.query === null || !this.state.ended) { this.timeoutId = setTimeout(this.refreshLoop, 1000); } }, @@ -547,7 +547,7 @@ let StagePerformance = React.createClass({ }, refreshLoop: function() { clearTimeout(this.timeoutId); // to stop multiple series of refreshLoop from going on simultaneously - const queryString = window.location.search.substring(1).split('.'); + const queryString = getFirstParameter(window.location.search).split('.'); const queryId = queryString[0]; let selectedStageId = this.state.selectedStageId; @@ -578,7 +578,7 @@ let StagePerformance = React.createClass({ this.refreshLoop(); }, findStage: function (stageId, currentStage) { - if (stageId == null) { + if (stageId === null) { return null; } @@ -588,7 +588,7 @@ let StagePerformance = React.createClass({ for (let i = 0; i < currentStage.subStages.length; i++) { const stage = this.findStage(stageId, currentStage.subStages[i]); - if (stage != null) { + if (stage !== null) { return stage; } } @@ -627,7 +627,7 @@ let StagePerformance = React.createClass({ this.getAllStageIds(allStages, query.outputStage); const stage = this.findStage(query.queryId + "." + this.state.selectedStageId, query.outputStage); - if (stage == null) { + if (stage === null) { return (

    Stage not found

    @@ -693,7 +693,7 @@ let StagePerformance = React.createClass({

    Stage { stage.plan.id }

    -
    +
    { task.stats.runningDrivers } + { task.stats.blockedDrivers } + { task.stats.completedDrivers } ID Host State Rows Rows/s Bytes - { stage.tasks.filter(task => task.taskStatus.state == "PLANNED").length } + { stage.tasks.filter(task => task.taskStatus.state === "PLANNED").length }
    - { stage.tasks.filter(task => task.taskStatus.state == "RUNNING").length } + { stage.tasks.filter(task => task.taskStatus.state === "RUNNING").length }
    - Finished + Blocked - { stage.tasks.filter(task => { - return task.taskStatus.state == "FINISHED" || - task.taskStatus.state == "CANCELED" || - task.taskStatus.state == "ABORTED" || - task.taskStatus.state == "FAILED" }).length } + { stage.tasks.filter(task => task.stats.fullyBlocked).length }
    -
    +
    -
    +
    - +
    -
    +
    -
    +
      { this.renderTaskFilterListItem(TASK_FILTER.ALL, "All") } @@ -811,7 +813,7 @@ let QueryDetail = React.createClass({ ); }, renderStages: function() { - if (this.state.lastSnapshotStage == null) { + if (this.state.lastSnapshotStage === null) { return; } @@ -952,7 +954,7 @@ let QueryDetail = React.createClass({ render: function() { const query = this.state.query; - if (query == null || this.state.initialized == false) { + if (query === null || this.state.initialized === false) { let label = (
      Loading...
      ); if (this.state.initialized) { label = "Query not found"; @@ -1119,7 +1121,7 @@ let QueryDetail = React.createClass({
    Resource Group + { query.resourceGroupName }
    - - Query - - Live Plan - - + Cluster Overview
    @@ -86,14 +81,20 @@ @@ -111,10 +112,7 @@ diff --git a/presto-main/src/main/resources/webapp/query.html b/presto-main/src/main/resources/webapp/query.html index 429775740e0f2..ec24e271efa5f 100644 --- a/presto-main/src/main/resources/webapp/query.html +++ b/presto-main/src/main/resources/webapp/query.html @@ -69,12 +69,7 @@ - - Query - - Overview - - + Cluster Overview @@ -83,14 +78,20 @@ @@ -104,10 +105,7 @@ diff --git a/presto-main/src/main/resources/webapp/stage.html b/presto-main/src/main/resources/webapp/stage.html index add8c9d13db63..09a7d86a796ff 100644 --- a/presto-main/src/main/resources/webapp/stage.html +++ b/presto-main/src/main/resources/webapp/stage.html @@ -72,10 +72,7 @@ - - Query - Stage Performance - + Cluster Overview @@ -84,14 +81,20 @@ @@ -116,10 +119,7 @@ diff --git a/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java b/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java index 9f7f31f172d45..5945f04514b6e 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/AbstractTestBlock.java @@ -21,11 +21,16 @@ import com.google.common.primitives.Ints; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; +import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; import java.lang.reflect.Array; +import java.lang.reflect.Field; +import java.util.IdentityHashMap; import java.util.List; +import java.util.Map; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; @@ -34,6 +39,7 @@ import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.slice.SizeOf.SIZE_OF_SHORT; +import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.toIntExact; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -49,6 +55,7 @@ protected void assertBlock(Block block, T[] expectedValues) assertBlockPositions(copyBlock(block), expectedValues); assertBlockSize(block); + assertRetainedSize(block); try { block.isNull(-1); @@ -64,6 +71,80 @@ protected void assertBlock(Block block, T[] expectedValues) } } + // copied from SliceArrayBlock, any changes should be reflected + private static long getSliceArrayRetainedSizeInBytes(Slice[] values) + { + long sizeInBytes = sizeOf(values); + Map uniqueRetained = new IdentityHashMap<>(values.length); + for (Slice value : values) { + if (value == null) { + continue; + } + if (value.getBase() != null && uniqueRetained.put(value.getBase(), true) == null) { + sizeInBytes += value.getRetainedSize(); + } + } + return sizeInBytes; + } + + private void assertRetainedSize(Block block) + { + long retainedSize = ClassLayout.parseClass(block.getClass()).instanceSize(); + Field[] fields = block.getClass().getDeclaredFields(); + try { + for (Field field : fields) { + Class type = field.getType(); + if (type.isPrimitive()) { + continue; + } + + field.setAccessible(true); + + if (type.equals(Slice.class)) { + retainedSize += ((Slice) field.get(block)).getRetainedSize(); + } + else if (type.equals(BlockBuilderStatus.class)) { + retainedSize += BlockBuilderStatus.INSTANCE_SIZE; + } + else if (type.equals(BlockBuilder.class) || type.equals(Block.class)) { + retainedSize += ((Block) field.get(block)).getRetainedSizeInBytes(); + } + else if (type.equals(Slice[].class)) { + retainedSize += getSliceArrayRetainedSizeInBytes((Slice[]) field.get(block)); + } + else if (type.equals(BlockBuilder[].class) || type.equals(Block[].class)) { + Block[] blocks = (Block[]) field.get(block); + for (Block innerBlock : blocks) { + assertRetainedSize(innerBlock); + retainedSize += innerBlock.getRetainedSizeInBytes(); + } + } + else if (type.equals(SliceOutput.class)) { + retainedSize += ((SliceOutput) field.get(block)).getRetainedSize(); + } + else if (type.equals(int[].class)) { + retainedSize += sizeOf((int[]) field.get(block)); + } + else if (type.equals(boolean[].class)) { + retainedSize += sizeOf((boolean[]) field.get(block)); + } + else if (type.equals(byte[].class)) { + retainedSize += sizeOf((byte[]) field.get(block)); + } + else if (type.equals(long[].class)) { + retainedSize += sizeOf((long[]) field.get(block)); + } + else if (type.equals(short[].class)) { + retainedSize += sizeOf((short[]) field.get(block)); + } + } + } + catch (IllegalAccessException | IllegalArgumentException t) { + throw new RuntimeException(t); + } + assertEquals(block.getRetainedSizeInBytes(), retainedSize); + } + protected void assertBlockFilteredPositions(T[] expectedValues, Block block, List positions) { Block filteredBlock = block.copyPositions(positions); @@ -106,17 +187,17 @@ private void assertBlockSize(Block block) { // Asserting on `block` is not very effective because most blocks passed to this method is compact. // Therefore, we split the `block` into two and assert again. - int expectedBlockSize = copyBlock(block).getSizeInBytes(); + long expectedBlockSize = copyBlock(block).getSizeInBytes(); assertEquals(block.getSizeInBytes(), expectedBlockSize); assertEquals(block.getRegionSizeInBytes(0, block.getPositionCount()), expectedBlockSize); List splitBlock = splitBlock(block, 2); Block firstHalf = splitBlock.get(0); - int expectedFirstHalfSize = copyBlock(firstHalf).getSizeInBytes(); + long expectedFirstHalfSize = copyBlock(firstHalf).getSizeInBytes(); assertEquals(firstHalf.getSizeInBytes(), expectedFirstHalfSize); assertEquals(block.getRegionSizeInBytes(0, firstHalf.getPositionCount()), expectedFirstHalfSize); Block secondHalf = splitBlock.get(1); - int expectedSecondHalfSize = copyBlock(secondHalf).getSizeInBytes(); + long expectedSecondHalfSize = copyBlock(secondHalf).getSizeInBytes(); assertEquals(secondHalf.getSizeInBytes(), expectedSecondHalfSize); assertEquals(block.getRegionSizeInBytes(firstHalf.getPositionCount(), secondHalf.getPositionCount()), expectedSecondHalfSize); } @@ -312,4 +393,13 @@ protected static Object[] alternatingNullValues(Object[] objects) objectsWithNulls[objectsWithNulls.length - 1] = null; return objectsWithNulls; } + + protected static Slice[] createExpectedUniqueValues(int positionCount) + { + Slice[] expectedValues = new Slice[positionCount]; + for (int position = 0; position < positionCount; position++) { + expectedValues[position] = Slices.copyOf(createExpectedValue(position)); + } + return expectedValues; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java b/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java index 19fa3f9b4f73f..c833bf51248d1 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/block/BlockAssertions.java @@ -18,9 +18,9 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import io.airlift.slice.Slice; import java.math.BigDecimal; @@ -158,7 +158,7 @@ public static Block createStringDictionaryBlock(int start, int length) for (int i = 0; i < length; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(length, builder.build(), ids); + return new DictionaryBlock(builder.build(), ids); } public static Block createStringArraysBlock(Iterable> values) @@ -346,7 +346,7 @@ public static Block createLongDictionaryBlock(int start, int length) for (int i = 0; i < length; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(length, builder.build(), ids); + return new DictionaryBlock(builder.build(), ids); } public static Block createLongRepeatBlock(int value, int length) diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java b/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java index c6564a659c995..9c582c1d8335c 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestBlockBuilder.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java index 25f4dc6d0eb4e..fcc986685204d 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java @@ -62,8 +62,8 @@ public void testCopyPositionsWithCompaction() assertEquals(copiedBlock.getDictionary().getPositionCount(), 1); assertEquals(copiedBlock.getPositionCount(), positionsToCopy.size()); - assertBlock(copiedBlock.getDictionary(), new Slice[]{firstExpectedValue}); - assertBlock(copiedBlock, new Slice[]{firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue}); + assertBlock(copiedBlock.getDictionary(), new Slice[] {firstExpectedValue}); + assertBlock(copiedBlock, new Slice[] {firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue, firstExpectedValue}); } @Test @@ -79,7 +79,7 @@ public void testCopyPositionsWithCompactionsAndReorder() assertEquals(copiedBlock.getDictionary().getPositionCount(), 2); assertEquals(copiedBlock.getPositionCount(), positionsToCopy.size()); - assertBlock(copiedBlock.getDictionary(), new Slice[] { expectedValues[0], expectedValues[5] }); + assertBlock(copiedBlock.getDictionary(), new Slice[] {expectedValues[0], expectedValues[5]}); assertDictionaryIds(copiedBlock, 0, 1, 0, 1, 0); } @@ -96,7 +96,7 @@ public void testCopyPositionsSamePosition() assertEquals(copiedBlock.getDictionary().getPositionCount(), 1); assertEquals(copiedBlock.getPositionCount(), positionsToCopy.size()); - assertBlock(copiedBlock.getDictionary(), new Slice[] { expectedValues[2] }); + assertBlock(copiedBlock.getDictionary(), new Slice[] {expectedValues[2]}); assertDictionaryIds(copiedBlock, 0, 0, 0); } @@ -126,7 +126,7 @@ public void testCompact() assertNotEquals(dictionaryBlock.getDictionarySourceId(), compactBlock.getDictionarySourceId()); assertEquals(compactBlock.getDictionary().getPositionCount(), (expectedValues.length / 2) + 1); - assertBlock(compactBlock.getDictionary(), new Slice[] { expectedValues[0], expectedValues[1], expectedValues[3] }); + assertBlock(compactBlock.getDictionary(), new Slice[] {expectedValues[0], expectedValues[1], expectedValues[3]}); assertDictionaryIds(compactBlock, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2); assertEquals(compactBlock.isCompact(), true); @@ -164,7 +164,7 @@ private static DictionaryBlock createDictionaryBlockWithUnreferencedKeys(Slice[] } ids[i] = index; } - return new DictionaryBlock(positionCount, new SliceArrayBlock(dictionarySize, expectedValues), ids); + return new DictionaryBlock(new SliceArrayBlock(dictionarySize, expectedValues), ids); } private static DictionaryBlock createDictionaryBlock(Slice[] expectedValues, int positionCount) @@ -175,7 +175,7 @@ private static DictionaryBlock createDictionaryBlock(Slice[] expectedValues, int for (int i = 0; i < positionCount; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(positionCount, new SliceArrayBlock(dictionarySize, expectedValues), ids); + return new DictionaryBlock(new SliceArrayBlock(dictionarySize, expectedValues), ids); } private static void assertDictionaryIds(DictionaryBlock dictionaryBlock, int... expected) diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java index f02332919f758..53c83b6b06fa7 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestInterleavedBlock.java @@ -80,11 +80,11 @@ private void testGetSizeInBytes() InterleavedBlock block = blockBuilder.build(); List splitQuarter = splitBlock(block, 4); - int sizeInBytes = block.getSizeInBytes(); - int quarter1size = splitQuarter.get(0).getSizeInBytes(); - int quarter2size = splitQuarter.get(1).getSizeInBytes(); - int quarter3size = splitQuarter.get(2).getSizeInBytes(); - int quarter4size = splitQuarter.get(3).getSizeInBytes(); + long sizeInBytes = block.getSizeInBytes(); + long quarter1size = splitQuarter.get(0).getSizeInBytes(); + long quarter2size = splitQuarter.get(1).getSizeInBytes(); + long quarter3size = splitQuarter.get(2).getSizeInBytes(); + long quarter4size = splitQuarter.get(3).getSizeInBytes(); double expectedQuarterSizeMin = sizeInBytes * 0.2; double expectedQuarterSizeMax = sizeInBytes * 0.3; assertTrue(quarter1size > expectedQuarterSizeMin && quarter1size < expectedQuarterSizeMax, format("quarter1size is %s, should be between %s and %s", quarter1size, expectedQuarterSizeMin, expectedQuarterSizeMax)); diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java new file mode 100644 index 0000000000000..3b49f65881481 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/block/TestMapBlock.java @@ -0,0 +1,204 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.block; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.block.SingleMapBlock; +import com.facebook.presto.spi.type.MapType; +import com.google.common.primitives.Ints; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.block.BlockAssertions.createLongsBlock; +import static com.facebook.presto.block.BlockAssertions.createStringsBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.util.StructuralTestUtil.mapType; +import static io.airlift.slice.Slices.utf8Slice; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +public class TestMapBlock + extends AbstractTestBlock +{ + @Test + public void test() + { + testWith(createTestMap(9, 3, 4, 0, 8, 0, 6, 5)); + } + + private Map[] createTestMap(int... entryCounts) + { + Map[] result = new Map[entryCounts.length]; + for (int rowNumber = 0; rowNumber < entryCounts.length; rowNumber++) { + int entryCount = entryCounts[rowNumber]; + Map map = new HashMap<>(); + for (int entryNumber = 0; entryNumber < entryCount; entryNumber++) { + map.put("key" + entryNumber, entryNumber == 5 ? null : rowNumber * 100L + entryNumber); + } + result[rowNumber] = map; + } + return result; + } + + private void testWith(Map[] expectedValues) + { + BlockBuilder blockBuilder = createBlockBuilderWithValues(expectedValues); + + assertBlock(blockBuilder, expectedValues); + assertBlock(blockBuilder.build(), expectedValues); + assertBlockFilteredPositions(expectedValues, blockBuilder, Ints.asList(0, 1, 3, 4, 7)); + assertBlockFilteredPositions(expectedValues, blockBuilder.build(), Ints.asList(0, 1, 3, 4, 7)); + assertBlockFilteredPositions(expectedValues, blockBuilder, Ints.asList(2, 3, 5, 6)); + assertBlockFilteredPositions(expectedValues, blockBuilder.build(), Ints.asList(2, 3, 5, 6)); + + Block block = createBlockWithValuesFromKeyValueBlock(expectedValues); + + assertBlock(block, expectedValues); + assertBlockFilteredPositions(expectedValues, block, Ints.asList(0, 1, 3, 4, 7)); + assertBlockFilteredPositions(expectedValues, block, Ints.asList(2, 3, 5, 6)); + + Map[] expectedValuesWithNull = (Map[]) alternatingNullValues(expectedValues); + BlockBuilder blockBuilderWithNull = createBlockBuilderWithValues(expectedValuesWithNull); + + assertBlock(blockBuilderWithNull, expectedValuesWithNull); + assertBlock(blockBuilderWithNull.build(), expectedValuesWithNull); + assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull, Ints.asList(0, 1, 5, 6, 7, 10, 11, 12, 15)); + assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull.build(), Ints.asList(0, 1, 5, 6, 7, 10, 11, 12, 15)); + assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull, Ints.asList(2, 3, 4, 9, 13, 14)); + assertBlockFilteredPositions(expectedValuesWithNull, blockBuilderWithNull.build(), Ints.asList(2, 3, 4, 9, 13, 14)); + + Block blockWithNull = createBlockWithValuesFromKeyValueBlock(expectedValuesWithNull); + + assertBlock(blockWithNull, expectedValuesWithNull); + assertBlockFilteredPositions(expectedValuesWithNull, blockWithNull, Ints.asList(0, 1, 5, 6, 7, 10, 11, 12, 15)); + assertBlockFilteredPositions(expectedValuesWithNull, blockWithNull, Ints.asList(2, 3, 4, 9, 13, 14)); + } + + private BlockBuilder createBlockBuilderWithValues(Map[] maps) + { + MapType mapType = mapType(VARCHAR, BIGINT); + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + for (Map map : maps) { + createBlockBuilderWithValues(map, mapBlockBuilder); + } + return mapBlockBuilder; + } + + private Block createBlockWithValuesFromKeyValueBlock(Map[] maps) + { + List keys = new ArrayList<>(); + List values = new ArrayList<>(); + int[] offsets = new int[maps.length + 1]; + boolean[] mapIsNull = new boolean[maps.length]; + for (int i = 0; i < maps.length; i++) { + Map map = maps[i]; + mapIsNull[i] = map == null; + if (map == null) { + offsets[i + 1] = offsets[i]; + } + else { + for (Map.Entry entry : map.entrySet()) { + keys.add(entry.getKey()); + values.add(entry.getValue()); + } + offsets[i + 1] = offsets[i] + map.size(); + } + } + return mapType(VARCHAR, BIGINT).createBlockFromKeyValue(mapIsNull, offsets, createStringsBlock(keys), createLongsBlock(values)); + } + + private void createBlockBuilderWithValues(Map map, BlockBuilder mapBlockBuilder) + { + if (map == null) { + mapBlockBuilder.appendNull(); + } + else { + BlockBuilder elementBlockBuilder = mapBlockBuilder.beginBlockEntry(); + for (Map.Entry entry : map.entrySet()) { + VARCHAR.writeSlice(elementBlockBuilder, utf8Slice(entry.getKey())); + if (entry.getValue() == null) { + elementBlockBuilder.appendNull(); + } + else { + BIGINT.writeLong(elementBlockBuilder, entry.getValue()); + } + } + mapBlockBuilder.closeEntry(); + } + } + + @Override + protected void assertPositionValue(Block block, int position, T expectedValue) + { + if (expectedValue instanceof Map) { + assertValue(block, position, (Map) expectedValue); + return; + } + super.assertPositionValue(block, position, expectedValue); + } + + private void assertValue(Block mapBlock, int position, Map map) + { + MapType mapType = mapType(VARCHAR, BIGINT); + + // null maps are handled by assertPositionValue + requireNonNull(map, "map is null"); + + assertFalse(mapBlock.isNull(position)); + SingleMapBlock elementBlock = (SingleMapBlock) mapType.getObject(mapBlock, position); + assertEquals(elementBlock.getPositionCount(), map.size() * 2); + + // Test new/hash-index access: assert inserted keys + for (Map.Entry entry : map.entrySet()) { + int pos = elementBlock.seekKey(utf8Slice(entry.getKey())); + assertNotEquals(pos, -1); + if (entry.getValue() == null) { + assertTrue(elementBlock.isNull(pos)); + } + else { + assertFalse(elementBlock.isNull(pos)); + assertEquals(BIGINT.getLong(elementBlock, pos), (long) entry.getValue()); + } + } + // Test new/hash-index access: assert non-existent keys + for (int i = 0; i < 10; i++) { + assertEquals(elementBlock.seekKey(utf8Slice("not-inserted-" + i)), -1); + } + + // Test legacy/iterative access + for (int i = 0; i < elementBlock.getPositionCount(); i += 2) { + String actualKey = VARCHAR.getSlice(elementBlock, i).toStringUtf8(); + Long actualValue; + if (elementBlock.isNull(i + 1)) { + actualValue = null; + } + else { + actualValue = BIGINT.getLong(elementBlock, i + 1); + } + assertTrue(map.containsKey(actualKey)); + assertEquals(actualValue, map.get(actualKey)); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestSliceArrayBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestSliceArrayBlock.java index 4f399e9d552d5..98c47254299ab 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestSliceArrayBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestSliceArrayBlock.java @@ -25,8 +25,15 @@ public class TestSliceArrayBlock public void test() { Slice[] expectedValues = createExpectedValues(100); - assertVariableWithValues(expectedValues); - assertVariableWithValues((Slice[]) alternatingNullValues(expectedValues)); + assertVariableWithValues(expectedValues, false); + assertVariableWithValues((Slice[]) alternatingNullValues(expectedValues), false); + } + + @Test + public void testDistinctSlices() + { + Slice[] expectedValues = createExpectedUniqueValues(100); + assertVariableWithValues(expectedValues, true); } @Test @@ -38,9 +45,9 @@ public void testCopyPositions() assertBlockFilteredPositions(expectedValues, block, Ints.asList(0, 2, 4, 6, 7, 9, 10, 16)); } - private void assertVariableWithValues(Slice[] expectedValues) + private void assertVariableWithValues(Slice[] expectedValues, boolean valueSlicesAreDistinct) { - SliceArrayBlock block = new SliceArrayBlock(expectedValues.length, expectedValues); + SliceArrayBlock block = new SliceArrayBlock(expectedValues.length, expectedValues, valueSlicesAreDistinct); assertBlock(block, expectedValues); } } diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java index 1507cd65cf16c..d7e4e82dd2576 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestVariableWidthBlock.java @@ -95,11 +95,11 @@ private void testGetSizeInBytes() Block block = blockBuilder.build(); List splitQuarter = splitBlock(block, 4); - int sizeInBytes = block.getSizeInBytes(); - int quarter1size = splitQuarter.get(0).getSizeInBytes(); - int quarter2size = splitQuarter.get(1).getSizeInBytes(); - int quarter3size = splitQuarter.get(2).getSizeInBytes(); - int quarter4size = splitQuarter.get(3).getSizeInBytes(); + long sizeInBytes = block.getSizeInBytes(); + long quarter1size = splitQuarter.get(0).getSizeInBytes(); + long quarter2size = splitQuarter.get(1).getSizeInBytes(); + long quarter3size = splitQuarter.get(2).getSizeInBytes(); + long quarter4size = splitQuarter.get(3).getSizeInBytes(); double expectedQuarterSizeMin = sizeInBytes * 0.2; double expectedQuarterSizeMax = sizeInBytes * 0.3; assertTrue(quarter1size > expectedQuarterSizeMin && quarter1size < expectedQuarterSizeMax, format("quarter1size is %s, should be between %s and %s", quarter1size, expectedQuarterSizeMin, expectedQuarterSizeMax)); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCoefficientBasedCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCoefficientBasedCostCalculator.java new file mode 100644 index 0000000000000..5641fd9bfa2cf --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCoefficientBasedCostCalculator.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.sql.planner.LogicalPlanner; +import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.assertions.PlanAssert; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestCoefficientBasedCostCalculator +{ + private final LocalQueryRunner queryRunner; + private final CostCalculator costCalculator; + + public TestCoefficientBasedCostCalculator() + { + this.queryRunner = new LocalQueryRunner(testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel + .build()); + + queryRunner.createCatalog( + queryRunner.getDefaultSession().getCatalog().get(), + new TpchConnectorFactory(1, true), + ImmutableMap.of()); + + costCalculator = new CoefficientBasedCostCalculator(queryRunner.getMetadata()); + } + + @Test + public void testCostCalculatorUsesLayout() + { + assertPlan("SELECT orderstatus FROM orders WHERE orderstatus = 'P'", + anyTree( + node(FilterNode.class, + node(TableScanNode.class) + .withCost(PlanNodeCost.builder() + .setOutputRowCount(new Estimate(385.0)) + .setOutputSizeInBytes(unknownValue()) + .build())))); + + assertPlan("SELECT orderstatus FROM orders WHERE orderkey = 42", + anyTree( + node(FilterNode.class, + node(TableScanNode.class) + .withCost(PlanNodeCost.builder() + .setOutputRowCount(new Estimate(0.0)) + .setOutputSizeInBytes(unknownValue()) + .build())))); + } + + private void assertPlan(String sql, PlanMatchPattern pattern) + { + assertPlan(sql, LogicalPlanner.Stage.CREATED, pattern); + } + + private void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern) + { + queryRunner.inTransaction(transactionSession -> { + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, stage); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), costCalculator, actualPlan, pattern); + return null; + }); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java b/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java index a3afc75b62104..b121714cdf7f8 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/MockQueryExecution.java @@ -19,10 +19,13 @@ import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.facebook.presto.sql.planner.Plan; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; import io.airlift.units.Duration; +import org.joda.time.DateTime; import java.net.URI; import java.util.ArrayList; @@ -35,8 +38,10 @@ import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; public class MockQueryExecution implements QueryExecution @@ -95,7 +100,50 @@ public QueryInfo getQueryInfo() URI.create("http://test"), ImmutableList.of(), "SELECT 1", - new QueryStats(), + new QueryStats( + new DateTime(1), + new DateTime(2), + new DateTime(3), + new DateTime(4), + new Duration(6, NANOSECONDS), + new Duration(5, NANOSECONDS), + new Duration(7, NANOSECONDS), + new Duration(8, NANOSECONDS), + + new Duration(100, NANOSECONDS), + new Duration(200, NANOSECONDS), + + 9, + 10, + 11, + + 12, + 13, + 15, + 30, + 16, + + 17.0, + new DataSize(18, BYTE), + new DataSize(19, BYTE), + + true, + new Duration(20, NANOSECONDS), + new Duration(21, NANOSECONDS), + new Duration(22, NANOSECONDS), + new Duration(23, NANOSECONDS), + false, + ImmutableSet.of(), + + new DataSize(24, BYTE), + 25, + + new DataSize(26, BYTE), + 27, + + new DataSize(28, BYTE), + 29, + ImmutableList.of()), ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of(), @@ -118,6 +166,12 @@ public QueryState getState() return state; } + @Override + public Plan getQueryPlan() + { + throw new UnsupportedOperationException(); + } + public Throwable getFailureCause() { return failureCause; diff --git a/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java b/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java index 77e97c2336ed2..81559284d3861 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java @@ -33,12 +33,12 @@ import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TestingColumnHandle; -import com.facebook.presto.sql.planner.TestingTableHandle; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java index 323d7f146cd51..5bf137c6caad0 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TaskTestUtils.java @@ -18,6 +18,7 @@ import com.facebook.presto.TaskSource; import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.cost.CoefficientBasedCostCalculator; import com.facebook.presto.execution.TestSqlTaskManager.MockExchangeClientSupplier; import com.facebook.presto.execution.scheduler.LegacyNetworkTopology; import com.facebook.presto.execution.scheduler.NodeScheduler; @@ -52,11 +53,11 @@ import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TestingColumnHandle; -import com.facebook.presto.sql.planner.TestingTableHandle; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.facebook.presto.testing.TestingSplit; import com.facebook.presto.testing.TestingTransactionHandle; import com.facebook.presto.util.FinalizerService; @@ -67,6 +68,7 @@ import java.util.Optional; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; @@ -95,7 +97,7 @@ private TaskTestUtils() TABLE_SCAN_NODE_ID, new TableHandle(CONNECTOR_ID, new TestingTableHandle()), ImmutableList.of(SYMBOL), - ImmutableMap.of(SYMBOL, new TestingColumnHandle("column")), + ImmutableMap.of(SYMBOL, new TestingColumnHandle("column", 0, BIGINT)), Optional.empty(), TupleDomain.all(), null), @@ -125,6 +127,7 @@ public static LocalExecutionPlanner createTestingPlanner() return new LocalExecutionPlanner( metadata, new SqlParser(), + new CoefficientBasedCostCalculator(metadata), Optional.empty(), pageSourceManager, new IndexManager(), diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestQueryStats.java b/presto-main/src/test/java/com/facebook/presto/execution/TestQueryStats.java index 31b0cf80619c7..f0f10a539c4bd 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestQueryStats.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestQueryStats.java @@ -48,6 +48,7 @@ public class TestQueryStats 12, 13, 15, + 30, 16, 17.0, @@ -106,6 +107,7 @@ public static void assertExpectedQueryStats(QueryStats actual) assertEquals(actual.getTotalDrivers(), 12); assertEquals(actual.getQueuedDrivers(), 13); assertEquals(actual.getRunningDrivers(), 15); + assertEquals(actual.getBlockedDrivers(), 30); assertEquals(actual.getCompletedDrivers(), 16); assertEquals(actual.getCumulativeMemory(), 17.0); diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestStageStats.java b/presto-main/src/test/java/com/facebook/presto/execution/TestStageStats.java index eb90aab2290f2..75f03300300e1 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestStageStats.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestStageStats.java @@ -43,6 +43,7 @@ public class TestStageStats 7, 8, 10, + 26, 11, 12.0, @@ -93,6 +94,7 @@ public static void assertExpectedStageStats(StageStats actual) assertEquals(actual.getTotalDrivers(), 7); assertEquals(actual.getQueuedDrivers(), 8); assertEquals(actual.getRunningDrivers(), 10); + assertEquals(actual.getBlockedDrivers(), 26); assertEquals(actual.getCompletedDrivers(), 11); assertEquals(actual.getCumulativeMemory(), 12.0); diff --git a/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java b/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java index bec068a2e31a5..7760120bf571c 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java @@ -15,6 +15,8 @@ import com.facebook.presto.execution.MockQueryExecution; import com.facebook.presto.execution.resourceGroups.InternalResourceGroup.RootInternalResourceGroup; +import com.facebook.presto.server.QueryStateInfo; +import com.facebook.presto.server.ResourceGroupStateInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; import com.google.common.collect.ImmutableSet; import io.airlift.units.DataSize; @@ -26,6 +28,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.SortedMap; @@ -34,6 +37,7 @@ import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; +import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.QUERY_PRIORITY; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.WEIGHTED; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -42,6 +46,7 @@ import static io.airlift.units.DataSize.Unit.BYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Collections.reverse; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; import static org.testng.Assert.assertEquals; @@ -470,6 +475,120 @@ public void testGetInfo() assertEquals(info.getNumAggregatedQueuedQueries(), 26); } + @Test + public void testGetResourceGroupStateInfo() + { + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> { }, directExecutor()); + root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + root.setMaxQueuedQueries(40); + root.setMaxRunningQueries(10); + root.setSchedulingPolicy(WEIGHTED); + + InternalResourceGroup rootA = root.getOrCreateSubGroup("a"); + rootA.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootA.setMaxQueuedQueries(20); + rootA.setMaxRunningQueries(0); + + InternalResourceGroup rootB = root.getOrCreateSubGroup("b"); + rootB.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootB.setMaxQueuedQueries(20); + rootB.setMaxRunningQueries(1); + rootB.setSchedulingWeight(2); + rootB.setSchedulingPolicy(QUERY_PRIORITY); + + InternalResourceGroup rootAX = rootA.getOrCreateSubGroup("x"); + rootAX.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootAX.setMaxQueuedQueries(10); + rootAX.setMaxRunningQueries(10); + + InternalResourceGroup rootAY = rootA.getOrCreateSubGroup("y"); + rootAY.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootAY.setMaxQueuedQueries(10); + rootAY.setMaxRunningQueries(10); + + Set queries = fillGroupTo(rootAX, ImmutableSet.of(), 5, false); + queries.addAll(fillGroupTo(rootAY, ImmutableSet.of(), 5, false)); + queries.addAll(fillGroupTo(rootB, ImmutableSet.of(), 10, true)); + + ResourceGroupStateInfo stateInfo = root.getStateInfo(); + assertEquals(stateInfo.getId(), root.getId()); + assertEquals(stateInfo.getState(), CAN_RUN); + assertEquals(stateInfo.getSoftMemoryLimit(), root.getSoftMemoryLimit()); + assertEquals(stateInfo.getMemoryUsage(), new DataSize(0, BYTE)); + assertEquals(stateInfo.getMaxRunningQueries(), root.getMaxRunningQueries()); + assertEquals(stateInfo.getRunningTimeLimit(), new Duration(Long.MAX_VALUE, MILLISECONDS)); + assertEquals(stateInfo.getMaxQueuedQueries(), root.getMaxQueuedQueries()); + assertEquals(stateInfo.getQueuedTimeLimit(), new Duration(Long.MAX_VALUE, MILLISECONDS)); + assertEquals(stateInfo.getNumQueuedQueries(), 19); + assertEquals(stateInfo.getRunningQueries().size(), 1); + QueryStateInfo queryInfo = stateInfo.getRunningQueries().get(0); + assertEquals(queryInfo.getResourceGroupId(), Optional.of(rootB.getId())); + } + + @Test + public void testGetBlockedQueuedQueries() + { + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> { }, directExecutor()); + root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + root.setMaxQueuedQueries(40); + // Start with zero capacity, so that nothing starts running until we've added all the queries + root.setMaxRunningQueries(0); + + InternalResourceGroup rootA = root.getOrCreateSubGroup("a"); + rootA.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootA.setMaxQueuedQueries(20); + rootA.setMaxRunningQueries(8); + + InternalResourceGroup rootAX = rootA.getOrCreateSubGroup("x"); + rootAX.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootAX.setMaxQueuedQueries(10); + rootAX.setMaxRunningQueries(8); + + InternalResourceGroup rootAY = rootA.getOrCreateSubGroup("y"); + rootAY.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootAY.setMaxQueuedQueries(10); + rootAY.setMaxRunningQueries(5); + + InternalResourceGroup rootB = root.getOrCreateSubGroup("b"); + rootB.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootB.setMaxQueuedQueries(20); + rootB.setMaxRunningQueries(8); + + InternalResourceGroup rootBX = rootB.getOrCreateSubGroup("x"); + rootBX.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootBX.setMaxQueuedQueries(10); + rootBX.setMaxRunningQueries(8); + + InternalResourceGroup rootBY = rootB.getOrCreateSubGroup("y"); + rootBY.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); + rootBY.setMaxQueuedQueries(10); + rootBY.setMaxRunningQueries(5); + + // Queue 40 queries (= maxQueuedQueries (40) + maxRunningQueries (0)) + Set queries = fillGroupTo(rootAX, ImmutableSet.of(), 10, false); + queries.addAll(fillGroupTo(rootAY, ImmutableSet.of(), 10, false)); + queries.addAll(fillGroupTo(rootBX, ImmutableSet.of(), 10, true)); + queries.addAll(fillGroupTo(rootBY, ImmutableSet.of(), 10, true)); + + assertEquals(root.getWaitingQueuedQueries(), 16); + assertEquals(rootA.getWaitingQueuedQueries(), 13); + assertEquals(rootAX.getWaitingQueuedQueries(), 10); + assertEquals(rootAY.getWaitingQueuedQueries(), 10); + assertEquals(rootB.getWaitingQueuedQueries(), 13); + assertEquals(rootBX.getWaitingQueuedQueries(), 10); + assertEquals(rootBY.getWaitingQueuedQueries(), 10); + + root.setMaxRunningQueries(20); + root.processQueuedQueries(); + assertEquals(root.getWaitingQueuedQueries(), 0); + assertEquals(rootA.getWaitingQueuedQueries(), 5); + assertEquals(rootAX.getWaitingQueuedQueries(), 6); + assertEquals(rootAY.getWaitingQueuedQueries(), 6); + assertEquals(rootB.getWaitingQueuedQueries(), 5); + assertEquals(rootBX.getWaitingQueuedQueries(), 6); + assertEquals(rootBY.getWaitingQueuedQueries(), 6); + } + private static Set fillGroupTo(InternalResourceGroup group, Set existingQueries, int count) { return fillGroupTo(group, existingQueries, count, false); diff --git a/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestUpdateablePriorityQueue.java b/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestUpdateablePriorityQueue.java new file mode 100644 index 0000000000000..a17a82bd46191 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/resourceGroups/TestUpdateablePriorityQueue.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.resourceGroups; + +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestUpdateablePriorityQueue +{ + @Test + public void testFifoQueue() + { + assertEquals(populateAndExtract(new FifoQueue<>()), ImmutableList.of(1, 2, 3)); + } + + @Test + public void testIndexedPriorityQueue() + { + assertEquals(populateAndExtract(new IndexedPriorityQueue<>()), ImmutableList.of(3, 2, 1)); + } + + @Test + public void testStochasticPriorityQueue() + { + assertTrue(populateAndExtract(new StochasticPriorityQueue<>()).size() == 3); + } + + private static List populateAndExtract(UpdateablePriorityQueue queue) + { + queue.addOrUpdate(1, 1); + queue.addOrUpdate(2, 2); + queue.addOrUpdate(3, 3); + return ImmutableList.copyOf(queue); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java index 167ce2af49f59..51d473d3eedad 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java @@ -21,8 +21,6 @@ import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TestingColumnHandle; -import com.facebook.presto.sql.planner.TestingTableHandle; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -30,6 +28,8 @@ import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.UnionNode; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; diff --git a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java index 158f6961d98d6..759dadbccf3e2 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java @@ -43,13 +43,13 @@ import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.StageExecutionPlan; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TestingColumnHandle; -import com.facebook.presto.sql.planner.TestingTableHandle; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanFragmentId; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.facebook.presto.testing.TestingSplit; import com.facebook.presto.testing.TestingTransactionHandle; import com.facebook.presto.util.FinalizerService; diff --git a/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java b/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java index 41147d9f99667..9f28e676afe68 100644 --- a/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java +++ b/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java @@ -14,6 +14,7 @@ package com.facebook.presto.failureDetector; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.server.InternalCommunicationConfig; import com.google.inject.Binder; import com.google.inject.Injector; import com.google.inject.Key; @@ -59,6 +60,7 @@ public void testExcludesCurrentNode() @Override public void configure(Binder binder) { + configBinder(binder).bindConfig(InternalCommunicationConfig.class); configBinder(binder).bindConfig(QueryManagerConfig.class); discoveryBinder(binder).bindSelector("presto"); discoveryBinder(binder).bindHttpAnnouncement("presto"); diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java b/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java new file mode 100644 index 0000000000000..480bcbd7e7cac --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/metadata/DummyMetadata.java @@ -0,0 +1,415 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.spi.CatalogSchemaName; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnIdentity; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.TableIdentity; +import com.facebook.presto.spi.block.BlockEncodingSerde; +import com.facebook.presto.spi.connector.ConnectorOutputMetadata; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.security.GrantInfo; +import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.tree.QualifiedName; +import io.airlift.slice.Slice; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Set; + +public class DummyMetadata + implements Metadata +{ + @Override + public void verifyComparableOrderableContract() + { + throw new UnsupportedOperationException(); + } + + @Override + public Type getType(TypeSignature signature) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isAggregationFunction(QualifiedName name) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listFunctions() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addFunctions(List functions) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean schemaExists(Session session, CatalogSchemaName schema) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listSchemaNames(Session session, String catalogName) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getTableHandle(Session session, QualifiedObjectName tableName) + { + throw new UnsupportedOperationException(); + } + + @Override + public List getLayouts( + Session session, + TableHandle tableHandle, + Constraint constraint, + Optional> desiredColumns) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableLayout getLayout(Session session, TableLayoutHandle handle) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getInfo(Session session, TableLayoutHandle handle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableMetadata getTableMetadata(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableStatistics getTableStatistics(Session session, TableHandle tableHandle, Constraint constraint) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listTables(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getColumnHandles(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map> listTableColumns(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createSchema(Session session, CatalogSchemaName schema, Map properties) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropSchema(Session session, CatalogSchemaName schema) + { + throw new UnsupportedOperationException(); + } + + @Override + public void renameSchema(Session session, CatalogSchemaName source, String target) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createTable(Session session, String catalogName, ConnectorTableMetadata tableMetadata) + { + throw new UnsupportedOperationException(); + } + + @Override + public void renameTable(Session session, TableHandle tableHandle, QualifiedObjectName newTableName) + { + throw new UnsupportedOperationException(); + } + + @Override + public void renameColumn(Session session, TableHandle tableHandle, ColumnHandle source, String target) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addColumn(Session session, TableHandle tableHandle, ColumnMetadata column) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropTable(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableIdentity getTableIdentity(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableIdentity deserializeTableIdentity(Session session, String catalogName, byte[] bytes) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnIdentity getColumnIdentity(Session session, TableHandle tableHandle, ColumnHandle columnHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnIdentity deserializeColumnIdentity(Session session, String catalogName, byte[] bytes) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getNewTableLayout(Session session, String catalogName, ConnectorTableMetadata tableMetadata) + { + throw new UnsupportedOperationException(); + } + + @Override + public OutputTableHandle beginCreateTable(Session session, String catalogName, ConnectorTableMetadata tableMetadata, Optional layout) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional finishCreateTable(Session session, OutputTableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getInsertLayout(Session session, TableHandle target) + { + throw new UnsupportedOperationException(); + } + + @Override + public void beginQuery(Session session, Set connectors) + { + throw new UnsupportedOperationException(); + } + + @Override + public void cleanupQuery(Session session) + { + throw new UnsupportedOperationException(); + } + + @Override + public InsertTableHandle beginInsert(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional finishInsert(Session session, InsertTableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean supportsMetadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public OptionalLong metadataDelete(Session session, TableHandle tableHandle, TableLayoutHandle tableLayoutHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public TableHandle beginDelete(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finishDelete(Session session, TableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getCatalogHandle(Session session, String catalogName) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getCatalogNames(Session session) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listViews(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getViews(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getView(Session session, QualifiedObjectName viewName) + { + throw new UnsupportedOperationException(); + } + + @Override + public void createView(Session session, QualifiedObjectName viewName, String viewData, boolean replace) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropView(Session session, QualifiedObjectName viewName) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional resolveIndex( + Session session, + TableHandle tableHandle, + Set indexableColumns, + Set outputColumns, + TupleDomain tupleDomain) + { + throw new UnsupportedOperationException(); + } + + @Override + public void grantTablePrivileges(Session session, QualifiedObjectName tableName, Set privileges, String grantee, boolean grantOption) + { + throw new UnsupportedOperationException(); + } + + @Override + public void revokeTablePrivileges(Session session, QualifiedObjectName tableName, Set privileges, String grantee, boolean grantOption) + { + throw new UnsupportedOperationException(); + } + + @Override + public List listTablePrivileges(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public FunctionRegistry getFunctionRegistry() + { + throw new UnsupportedOperationException(); + } + + @Override + public ProcedureRegistry getProcedureRegistry() + { + throw new UnsupportedOperationException(); + } + + @Override + public TypeManager getTypeManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public BlockEncodingSerde getBlockEncodingSerde() + { + throw new UnsupportedOperationException(); + } + + @Override + public SessionPropertyManager getSessionPropertyManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public SchemaPropertyManager getSchemaPropertyManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public TablePropertyManager getTablePropertyManager() + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java b/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java index 0951026ac06c8..9faa2bcbea120 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java @@ -14,6 +14,7 @@ package com.facebook.presto.metadata; import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.server.NoOpFailureDetector; import com.facebook.presto.spi.Node; import com.google.common.collect.ArrayListMultimap; @@ -49,6 +50,7 @@ public class TestDiscoveryNodeManager { private final NodeInfo nodeInfo = new NodeInfo("test"); + private final InternalCommunicationConfig internalCommunicationConfig = new InternalCommunicationConfig(); private NodeVersion expectedVersion; private List activeNodes; private List inactiveNodes; @@ -90,7 +92,7 @@ public void setup() public void testGetAllNodes() throws Exception { - DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient); + DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); AllNodes allNodes = manager.getAllNodes(); Set activeNodes = allNodes.getActiveNodes(); @@ -125,7 +127,7 @@ public void testGetCurrentNode() .setEnvironment("test") .setNodeId(expected.getNodeIdentifier())); - DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient); + DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); assertEquals(manager.getCurrentNode(), expected); } @@ -134,7 +136,7 @@ public void testGetCurrentNode() public void testGetCoordinators() throws Exception { - InternalNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient); + InternalNodeManager manager = new DiscoveryNodeManager(selector, nodeInfo, new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); assertEquals(manager.getCoordinators(), ImmutableSet.of(coordinator)); } @@ -142,6 +144,6 @@ public void testGetCoordinators() @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = ".* current node not returned .*") public void testGetCurrentNodeRequired() { - new DiscoveryNodeManager(selector, new NodeInfo("test"), new NoOpFailureDetector(), expectedVersion, testHttpClient); + new DiscoveryNodeManager(selector, new NodeInfo("test"), new NoOpFailureDetector(), expectedVersion, testHttpClient, internalCommunicationConfig); } } diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/TestSignatureBinder.java b/presto-main/src/test/java/com/facebook/presto/metadata/TestSignatureBinder.java index 3f6b325debfe4..c796c1eb4d2d8 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/TestSignatureBinder.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/TestSignatureBinder.java @@ -13,12 +13,15 @@ */ package com.facebook.presto.metadata; -import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; +import com.facebook.presto.type.FunctionType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -33,9 +36,11 @@ import static com.facebook.presto.metadata.Signature.comparableTypeParameter; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.metadata.Signature.withVariadicBound; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -50,7 +55,11 @@ public class TestSignatureBinder { - private final TypeRegistry typeRegistry = new TypeRegistry(); + private final TypeManager typeRegistry = new TypeRegistry(); + { + // associate typeRegistry with a function registry + new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); + } @Test public void testBindLiteralForDecimal() @@ -1002,7 +1011,7 @@ public void testFunction() assertThat(simple) .boundTo("function(integer,integer)") .succeeds(); - // TODO: This should eventually be supported + // TODO: Support coercion of return type of lambda assertThat(simple) .boundTo("function(integer,smallint)") .withCoercion() @@ -1077,6 +1086,26 @@ public void testFunction() assertThat(varargApply) .boundTo("integer", "function(integer, integer)", "function(integer, double)", "function(double, double)") .fails(); + + Signature loop = functionSignature() + .returnType(parseTypeSignature("T")) + .argumentTypes(parseTypeSignature("T"), parseTypeSignature("function(T, T)")) + .typeVariableConstraints(typeVariable("T")) + .build(); + assertThat(loop) + .boundTo("integer", new TypeSignatureProvider(paramTypes -> new FunctionType(paramTypes, BIGINT).getTypeSignature())) + .fails(); + assertThat(loop) + .boundTo("integer", new TypeSignatureProvider(paramTypes -> new FunctionType(paramTypes, BIGINT).getTypeSignature())) + .withCoercion() + .produces(BoundVariables.builder() + .setTypeVariable("T", BIGINT) + .build()); + // TODO: Support coercion of return type of lambda + assertThat(loop) + .withCoercion() + .boundTo("integer", new TypeSignatureProvider(paramTypes -> new FunctionType(paramTypes, SMALLINT).getTypeSignature())) + .fails(); } @Test @@ -1085,7 +1114,7 @@ public void testBindParameters() { BoundVariables boundVariables = BoundVariables.builder() .setTypeVariable("T1", DOUBLE) - .setTypeVariable("T2", BigintType.BIGINT) + .setTypeVariable("T2", BIGINT) .setTypeVariable("T3", DecimalType.createDecimalType(5, 3)) .setLongVariable("p", 1L) .setLongVariable("s", 2L) @@ -1171,12 +1200,6 @@ public BindSignatureAssertion withCoercion() return this; } - public BindSignatureAssertion boundTo(String... arguments) - { - this.argumentTypes = fromTypes(types(arguments)); - return this; - } - public BindSignatureAssertion boundTo(Object... arguments) { ImmutableList.Builder builder = ImmutableList.builder(); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java index 12a6a193891f5..f38e7b0997e8b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java @@ -54,6 +54,7 @@ import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.openjdk.jmh.annotations.Mode.AverageTime; import static org.openjdk.jmh.annotations.Scope.Thread; @@ -63,7 +64,7 @@ @BenchmarkMode(AverageTime) @Fork(3) @Warmup(iterations = 5) -@Measurement(iterations = 20) +@Measurement(iterations = 10, time = 2, timeUnit = SECONDS) public class BenchmarkHashBuildAndJoinOperators { private static final int HASH_BUILD_OPERATOR_ID = 1; @@ -75,16 +76,16 @@ public class BenchmarkHashBuildAndJoinOperators public static class BuildContext { protected static final int ROWS_PER_PAGE = 1024; - protected static final int BUILD_ROWS_NUMBER = 700_000; + protected static final int BUILD_ROWS_NUMBER = 8_000_000; @Param({"varchar", "bigint", "all"}) - protected String hashColumns; + protected String hashColumns = "bigint"; @Param({"false", "true"}) - protected boolean buildHashEnabled; + protected boolean buildHashEnabled = false; @Param({"1", "5"}) - protected int buildRowsRepetition; + protected int buildRowsRepetition = 1; protected ExecutorService executor; protected List buildPages; @@ -161,13 +162,13 @@ protected void initializeBuildPages() public static class JoinContext extends BuildContext { - protected static final int PROBE_ROWS_NUMBER = 700_000; + protected static final int PROBE_ROWS_NUMBER = 1_400_000; @Param({"0.1", "1", "2"}) - protected double matchRate; + protected double matchRate = 1; @Param({"bigint", "all"}) - protected String outputColumns; + protected String outputColumns = "bigint"; protected List probePages; protected List outputChannels; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java b/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java index 0820c9a3670eb..c02b2d58a2b61 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/OperatorAssertion.java @@ -23,19 +23,25 @@ import com.facebook.presto.testing.MaterializedResult; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ListenableFuture; import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.facebook.presto.type.TypeJsonUtils.appendToBlockBuilder; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; public final class OperatorAssertion { @@ -45,96 +51,84 @@ private OperatorAssertion() public static List toPages(Operator operator, Iterator input) { - ImmutableList.Builder outputPages = ImmutableList.builder(); + return ImmutableList.builder() + .addAll(toPagesPartial(operator, input)) + .addAll(finishOperator(operator)) + .build(); + } - boolean finishing = false; + public static List toPagesPartial(Operator operator, Iterator input) + { + // verify initial state + assertEquals(operator.isFinished(), false); - while (operator.needsInput() && input.hasNext()) { - operator.addInput(input.next()); - } + ImmutableList.Builder outputPages = ImmutableList.builder(); + for (int loopsSinceLastPage = 0; loopsSinceLastPage < 1_000; loopsSinceLastPage++) { + if (handledBlocked(operator)) { + continue; + } - for (int loops = 0; !operator.isFinished() && loops < 10_000; loops++) { - if (operator.needsInput()) { - if (input.hasNext()) { - Page inputPage = input.next(); - operator.addInput(inputPage); - } - else if (!finishing) { - operator.finish(); - finishing = true; - } + if (input.hasNext() && operator.needsInput()) { + operator.addInput(input.next()); + loopsSinceLastPage = 0; } Page outputPage = operator.getOutput(); - if (outputPage != null) { + if (outputPage != null && outputPage.getPositionCount() != 0) { outputPages.add(outputPage); + loopsSinceLastPage = 0; } } - assertFalse(operator.needsInput()); - assertTrue(operator.isBlocked().isDone()); - assertTrue(operator.isFinished()); - return outputPages.build(); } - public static List toPages(OperatorFactory operatorFactory, DriverContext driverContext, List input) + public static List finishOperator(Operator operator) { - try (Operator operator = operatorFactory.createOperator(driverContext)) { - return toPages(operator, input); - } - catch (Exception e) { - throw Throwables.propagate(e); + ImmutableList.Builder outputPages = ImmutableList.builder(); + + for (int loopsSinceLastPage = 0; !operator.isFinished() && loopsSinceLastPage < 1_000; loopsSinceLastPage++) { + if (handledBlocked(operator)) { + continue; + } + operator.finish(); + Page outputPage = operator.getOutput(); + if (outputPage != null && outputPage.getPositionCount() != 0) { + outputPages.add(outputPage); + loopsSinceLastPage = 0; + } } + + assertEquals(operator.isFinished(), true, "Operator did not finish"); + assertEquals(operator.needsInput(), false, "Operator still wants input"); + assertEquals(operator.isBlocked().isDone(), true, "Operator is blocked"); + + return outputPages.build(); } - private static List toPages(Operator operator, List input) + private static boolean handledBlocked(Operator operator) { - // verify initial state - assertEquals(operator.isFinished(), false); - assertEquals(operator.needsInput(), true); - assertEquals(operator.getOutput(), null); - - return toPages(operator, input.iterator()); + ListenableFuture isBlocked = operator.isBlocked(); + if (!isBlocked.isDone()) { + tryGetFutureValue(isBlocked, 1, TimeUnit.MILLISECONDS); + return true; + } + return false; } - public static List toPages(OperatorFactory operatorFactory, DriverContext driverContext) + public static List toPages(OperatorFactory operatorFactory, DriverContext driverContext, List input) { try (Operator operator = operatorFactory.createOperator(driverContext)) { - return toPages(operator); + return toPages(operator, input.iterator()); } catch (Exception e) { throw Throwables.propagate(e); } } - private static List toPages(Operator operator) - { - // operator does not have input so should never require input - assertEquals(operator.needsInput(), false); - - ImmutableList.Builder outputPages = ImmutableList.builder(); - addRemainingOutputPages(operator, outputPages); - return outputPages.build(); - } - - private static void addRemainingOutputPages(Operator operator, ImmutableList.Builder outputPages) + public static List toPages(OperatorFactory operatorFactory, DriverContext driverContext) { - // pull remaining output pages - while (!operator.isFinished()) { - // at this point the operator should not need more input - assertEquals(operator.needsInput(), false); - - Page outputPage = operator.getOutput(); - if (outputPage != null) { - outputPages.add(outputPage); - } - } - - // verify final state - assertEquals(operator.isFinished(), true); - assertEquals(operator.needsInput(), false); - assertEquals(operator.getOutput(), null); + return toPages(operatorFactory, driverContext, ImmutableList.of()); } public static MaterializedResult toMaterializedResult(Session session, List types, List pages) @@ -223,15 +217,14 @@ public static void assertOperatorEqualsIgnoreOrder( assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); } - static List without(List types, List channels) + static List without(List list, Collection indexes) { - types = new ArrayList<>(types); - int removed = 0; - for (int hashChannel : channels) { - types.remove(hashChannel - removed); - removed++; - } - return ImmutableList.copyOf(types); + Set indexesSet = ImmutableSet.copyOf(indexes); + + return IntStream.range(0, list.size()) + .filter(index -> !indexesSet.contains(index)) + .mapToObj(list::get) + .collect(toImmutableList()); } static List dropChannel(List pages, List channels) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java index 1b183e93255db..7dbf77349159a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/PageAssertions.java @@ -31,6 +31,7 @@ public static void assertPageEquals(List types, Page actualPage, { assertEquals(types.size(), actualPage.getChannelCount()); assertEquals(actualPage.getChannelCount(), expectedPage.getChannelCount()); + assertEquals(actualPage.getPositionCount(), expectedPage.getPositionCount()); for (int i = 0; i < actualPage.getChannelCount(); i++) { assertBlockEquals(types.get(i), actualPage.getBlock(i), expectedPage.getBlock(i)); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java index f6c6550be3ee9..c413a223bb992 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java @@ -648,8 +648,8 @@ public void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelBu assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages)); } - @Test(expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded local memory limit of.*", dataProvider = "hashEnabledValues") - public void testMemoryLimit(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) + @Test(expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded local memory limit of.*", dataProvider = "testMemoryLimitProvider") + public void testMemoryLimit(boolean parallelBuild, boolean buildHashEnabled) throws Exception { TaskContext taskContext = TestingTaskContext.createTaskContext(executor, TEST_SESSION, new DataSize(100, BYTE)); @@ -659,6 +659,16 @@ public void testMemoryLimit(boolean parallelBuild, boolean probeHashEnabled, boo buildHash(parallelBuild, taskContext, Ints.asList(0), buildPages, Optional.empty()); } + @DataProvider + public static Object[][] testMemoryLimitProvider() + { + return new Object[][] { + {true, true}, + {true, false}, + {false, true}, + {false, false}}; + } + private TaskContext createTaskContext() { return TestingTaskContext.createTaskContext(executor, TEST_SESSION); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java index 3a1e0b91dc1e4..d26788089175c 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java @@ -175,8 +175,7 @@ public void testBuildSideNulls(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - //Disabled till #6622 is fixed - @Test(dataProvider = "hashEnabledValues", enabled = false) + @Test(dataProvider = "hashEnabledValues") public void testProbeSideNulls(boolean hashEnabled) throws Exception { @@ -226,8 +225,7 @@ public void testProbeSideNulls(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - //Disabled till #6622 is fixed - @Test(dataProvider = "hashEnabledValues", enabled = false) + @Test(dataProvider = "hashEnabledValues") public void testProbeAndBuildNulls(boolean hashEnabled) throws Exception { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestJoinOperatorInfo.java b/presto-main/src/test/java/com/facebook/presto/operator/TestJoinOperatorInfo.java new file mode 100644 index 0000000000000..2e03c8df2332e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestJoinOperatorInfo.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.operator.LookupJoinOperators.JoinType; +import org.testng.annotations.Test; + +import static com.google.common.base.Preconditions.checkArgument; +import static org.testng.Assert.assertEquals; + +public class TestJoinOperatorInfo +{ + @Test + public void testMerge() + { + JoinOperatorInfo base = new JoinOperatorInfo( + JoinType.INNER, + makeHistogramArray(10, 20, 30, 40, 50, 60, 70, 80), + makeHistogramArray(12, 22, 32, 42, 52, 62, 72, 82)); + JoinOperatorInfo other = new JoinOperatorInfo( + JoinType.INNER, + makeHistogramArray(11, 21, 31, 41, 51, 61, 71, 81), + makeHistogramArray(15, 25, 35, 45, 55, 65, 75, 85)); + + JoinOperatorInfo merged = base.mergeWith(other); + assertEquals(makeHistogramArray(21, 41, 61, 81, 101, 121, 141, 161), merged.getLogHistogramProbes()); + assertEquals(makeHistogramArray(27, 47, 67, 87, 107, 127, 147, 167), merged.getLogHistogramOutput()); + } + + private long[] makeHistogramArray(long... longArray) + { + checkArgument(longArray.length == 8); + return longArray; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestJoinStatisticsCounter.java b/presto-main/src/test/java/com/facebook/presto/operator/TestJoinStatisticsCounter.java new file mode 100644 index 0000000000000..5bdc3bf032a6f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestJoinStatisticsCounter.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.operator.LookupJoinOperators.JoinType; +import org.testng.annotations.Test; + +import static com.google.common.base.Preconditions.checkArgument; +import static org.testng.Assert.assertEquals; + +public class TestJoinStatisticsCounter +{ + @Test + public void testRecord() + { + JoinStatisticsCounter counter = new JoinStatisticsCounter(JoinType.INNER); + JoinOperatorInfo info = counter.get(); + assertEquals(makeHistogramArray(0, 0, 0, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 0, 0, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + + // 0 to 4 buckets + counter.recordProbe(0); + info = counter.get(); + assertEquals(makeHistogramArray(1, 0, 0, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 0, 0, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(0); + info = counter.get(); + assertEquals(makeHistogramArray(2, 0, 0, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 0, 0, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + + counter.recordProbe(1); + info = counter.get(); + assertEquals(makeHistogramArray(2, 1, 0, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 1, 0, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(1); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 0, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 0, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + + counter.recordProbe(2); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 1, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 2, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(2); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 0, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 0, 0, 0, 0, 0), info.getLogHistogramOutput()); + + counter.recordProbe(3); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 1, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 3, 0, 0, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(3); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 0, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 0, 0, 0, 0), info.getLogHistogramOutput()); + + counter.recordProbe(4); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 1, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 4, 0, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(4); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 0, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 0, 0, 0), info.getLogHistogramOutput()); + + // 5 to 10 + counter.recordProbe(5); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 1, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 5, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(6); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 2, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 11, 0, 0), info.getLogHistogramOutput()); + counter.recordProbe(10); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 3, 0, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 21, 0, 0), info.getLogHistogramOutput()); + + // 11 to 100 + counter.recordProbe(11); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 3, 1, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 21, 11, 0), info.getLogHistogramOutput()); + counter.recordProbe(100); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 3, 2, 0), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 21, 111, 0), info.getLogHistogramOutput()); + + // 101 and more + counter.recordProbe(101); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 3, 2, 1), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 21, 111, 101), info.getLogHistogramOutput()); + counter.recordProbe(1000); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 3, 2, 2), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 21, 111, 1101), info.getLogHistogramOutput()); + counter.recordProbe(1000000); + info = counter.get(); + assertEquals(makeHistogramArray(2, 2, 2, 2, 2, 3, 2, 3), info.getLogHistogramProbes()); + assertEquals(makeHistogramArray(0, 2, 4, 6, 8, 21, 111, 1001101), info.getLogHistogramOutput()); + } + + private long[] makeHistogramArray(long... longArray) + { + checkArgument(longArray.length == 8); + return longArray; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java b/presto-main/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java new file mode 100644 index 0000000000000..925a49909bb68 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.testing.assertions.Assert; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.units.Duration; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import static java.util.Collections.emptyIterator; +import static java.util.Objects.requireNonNull; + +public class TestOperatorAssertion +{ + private ScheduledExecutorService executor; + + @BeforeClass + public void setUp() + { + executor = Executors.newScheduledThreadPool(1); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + } + + @Test + public void testToPagesWithBlockedOperator() + { + Operator operator = new BlockedOperator(Duration.valueOf("15 ms")); + List pages = OperatorAssertion.toPages(operator, emptyIterator()); + Assert.assertEquals(pages, ImmutableList.of()); + } + + private class BlockedOperator + implements Operator + { + private final Duration unblockAfter; + private final OperatorContext operatorContext; + + private ListenableFuture isBlocked = NOT_BLOCKED; + + public BlockedOperator(Duration unblockAfter) + { + this.unblockAfter = requireNonNull(unblockAfter, "unblockAfter is null"); + this.operatorContext = TestingOperatorContext.create(); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public List getTypes() + { + throw new UnsupportedOperationException(); + } + + @Override + public ListenableFuture isBlocked() + { + return isBlocked; + } + + @Override + public boolean needsInput() + { + return false; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finish() + { + if (this.isBlocked == NOT_BLOCKED) { + SettableFuture isBlocked = SettableFuture.create(); + this.isBlocked = isBlocked; + executor.schedule(() -> isBlocked.set(null), unblockAfter.toMillis(), TimeUnit.MILLISECONDS); + } + } + + @Override + public boolean isFinished() + { + return isBlocked != NOT_BLOCKED // finish() not called yet + && isBlocked.isDone(); + } + + @Override + public Page getOutput() + { + return null; + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestPipelineStats.java b/presto-main/src/test/java/com/facebook/presto/operator/TestPipelineStats.java index 67a0cf518b5fb..16c656ff971d3 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestPipelineStats.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestPipelineStats.java @@ -47,6 +47,7 @@ public class TestPipelineStats 1, 3, 2, + 19, 4, new DataSize(5, BYTE), @@ -98,6 +99,7 @@ public static void assertExpectedPipelineStats(PipelineStats actual) assertEquals(actual.getQueuedPartitionedDrivers(), 1); assertEquals(actual.getRunningDrivers(), 3); assertEquals(actual.getRunningPartitionedDrivers(), 2); + assertEquals(actual.getBlockedDrivers(), 19); assertEquals(actual.getCompletedDrivers(), 4); assertEquals(actual.getMemoryReservation(), new DataSize(5, BYTE)); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java b/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java index 6fcaec594c499..22160774b0e23 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java @@ -15,11 +15,14 @@ import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.spi.Page; -import it.unimi.dsi.fastutil.ints.IntComparator; +import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.longs.LongArrayList; import org.testng.annotations.Test; import java.util.Optional; +import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.collect.Iterables.getOnlyElement; import static org.testng.Assert.assertEquals; @@ -31,16 +34,16 @@ public class TestPositionLinks @Test public void testArrayPositionLinks() { - PositionLinks.Builder builder = ArrayPositionLinks.builder(1000); + PositionLinks.FactoryBuilder factoryBuilder = ArrayPositionLinks.builder(1000); - assertEquals(builder.link(1, 0), 1); - assertEquals(builder.link(2, 1), 2); - assertEquals(builder.link(3, 2), 3); + assertEquals(factoryBuilder.link(1, 0), 1); + assertEquals(factoryBuilder.link(2, 1), 2); + assertEquals(factoryBuilder.link(3, 2), 3); - assertEquals(builder.link(11, 10), 11); - assertEquals(builder.link(12, 11), 12); + assertEquals(factoryBuilder.link(11, 10), 11); + assertEquals(factoryBuilder.link(12, 11), 12); - PositionLinks positionLinks = builder.build().apply(Optional.empty()); + PositionLinks positionLinks = factoryBuilder.build().create(Optional.empty()); assertEquals(positionLinks.start(3, 0, TEST_PAGE), 3); assertEquals(positionLinks.next(3, 0, TEST_PAGE), 2); @@ -73,8 +76,8 @@ public Optional getSortChannel() } }; - PositionLinks.Builder builder = buildSortedPositionLinks(); - PositionLinks positionLinks = builder.build().apply(Optional.of(filterFunction)); + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(Optional.of(filterFunction)); assertEquals(positionLinks.start(0, 0, TEST_PAGE), 5); assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); @@ -104,8 +107,8 @@ public Optional getSortChannel() } }; - PositionLinks.Builder builder = buildSortedPositionLinks(); - PositionLinks positionLinks = builder.build().apply(Optional.of(filterFunction)); + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(Optional.of(filterFunction)); assertEquals(positionLinks.start(0, 0, TEST_PAGE), 0); assertEquals(positionLinks.next(0, 0, TEST_PAGE), 1); @@ -116,23 +119,12 @@ public Optional getSortChannel() assertEquals(positionLinks.start(10, 0, TEST_PAGE), -1); } - private static PositionLinks.Builder buildSortedPositionLinks() + private static PositionLinks.FactoryBuilder buildSortedPositionLinks() { - SortedPositionLinks.Builder builder = SortedPositionLinks.builder( + SortedPositionLinks.FactoryBuilder builder = SortedPositionLinks.builder( 1000, - new IntComparator() { - @Override - public int compare(int left, int right) - { - return BIGINT.compareTo(TEST_PAGE.getBlock(0), left, TEST_PAGE.getBlock(0), right); - } - - @Override - public int compare(Integer left, Integer right) - { - return compare(left.intValue(), right.intValue()); - } - }); + pagesHashStrategy(), + addresses()); assertEquals(builder.link(4, 5), 4); assertEquals(builder.link(6, 4), 4); @@ -146,4 +138,24 @@ public int compare(Integer left, Integer right) return builder; } + + private static PagesHashStrategy pagesHashStrategy() + { + return new SimplePagesHashStrategy( + ImmutableList.of(BIGINT), + ImmutableList.of(), + ImmutableList.of(ImmutableList.of(TEST_PAGE.getBlock(0))), + ImmutableList.of(), + Optional.empty(), + Optional.of(new SortExpression(0))); + } + + private static LongArrayList addresses() + { + LongArrayList addresses = new LongArrayList(); + for (int i = 0; i < TEST_PAGE.getPositionCount(); ++i) { + addresses.add(encodeSyntheticAddress(0, i)); + } + return addresses; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestTaskStats.java b/presto-main/src/test/java/com/facebook/presto/operator/TestTaskStats.java index 70de54aa7dbd0..adc30f635efdb 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestTaskStats.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestTaskStats.java @@ -43,6 +43,7 @@ public class TestTaskStats 5, 8, 6, + 24, 10, 11.0, @@ -92,6 +93,7 @@ public static void assertExpectedTaskStats(TaskStats actual) assertEquals(actual.getQueuedPartitionedDrivers(), 5); assertEquals(actual.getRunningDrivers(), 8); assertEquals(actual.getRunningPartitionedDrivers(), 6); + assertEquals(actual.getBlockedDrivers(), 24); assertEquals(actual.getCompletedDrivers(), 10); assertEquals(actual.getCumulativeMemory(), 11.0); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestTopNOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestTopNOperator.java index 7567b24a42748..17e0f79b0bad7 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestTopNOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestTopNOperator.java @@ -185,16 +185,9 @@ public void testLimitZero() new DataSize(16, MEGABYTE)); try (Operator operator = factory.createOperator(driverContext)) { - MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT).build(); - - // assertOperatorEquals assumes operators do not start in finished state assertEquals(operator.isFinished(), true); assertEquals(operator.needsInput(), false); assertEquals(operator.getOutput(), null); - - List pages = OperatorAssertion.toPages(operator, input.iterator()); - MaterializedResult actual = OperatorAssertion.toMaterializedResult(operator.getOperatorContext().getSession(), operator.getTypes(), pages); - assertEquals(actual, expected); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java index 80fe230f70273..77820223ea81e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestUnnestOperator.java @@ -15,10 +15,10 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.testing.MaterializedResult; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterMethod; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java index 1b6f189be3428..5f7bcc2ade81b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestWindowOperator.java @@ -57,6 +57,7 @@ import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static java.util.concurrent.Executors.newCachedThreadPool; +import static org.testng.Assert.assertEquals; @Test(singleThreaded = true) public class TestWindowOperator @@ -595,6 +596,39 @@ public void testFullyPreGroupedAndFullySortedPartition() assertOperatorEquals(operatorFactory, driverContext, input, expected); } + @Test + public void testFindEndPosition() + { + assertFindEndPosition("0", 1); + assertFindEndPosition("11", 2); + assertFindEndPosition("1111111111", 10); + + assertFindEndPosition("01", 1); + assertFindEndPosition("011", 1); + assertFindEndPosition("0111", 1); + assertFindEndPosition("0111111111", 1); + + assertFindEndPosition("012", 1); + assertFindEndPosition("01234", 1); + assertFindEndPosition("0123456789", 1); + + assertFindEndPosition("001", 2); + assertFindEndPosition("0001", 3); + assertFindEndPosition("0000000001", 9); + + assertFindEndPosition("00100", 2); + assertFindEndPosition("000111", 3); + assertFindEndPosition("0001111", 3); + assertFindEndPosition("0000111", 4); + assertFindEndPosition("000000000000001111111111", 14); + } + + private static void assertFindEndPosition(String values, int expected) + { + char[] array = values.toCharArray(); + assertEquals(WindowOperator.findEndPosition(0, array.length, (first, second) -> array[first] == array[second]), expected); + } + private static WindowOperatorFactory createFactoryUnbounded( List sourceTypes, List outputChannels, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestingOperatorContext.java b/presto-main/src/test/java/com/facebook/presto/operator/TestingOperatorContext.java new file mode 100644 index 0000000000000..5ab7f2161abc3 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestingOperatorContext.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.testing.TestingSession; +import com.facebook.presto.testing.TestingTaskContext; +import com.google.common.util.concurrent.MoreExecutors; + +import java.util.concurrent.Executor; + +public class TestingOperatorContext +{ + public static OperatorContext create() + { + Executor executor = MoreExecutors.directExecutor(); + + TaskContext taskContext = TestingTaskContext.createTaskContext( + executor, + TestingSession.testSessionBuilder().build()); + + PipelineContext pipelineContext = new PipelineContext( + 1, + taskContext, + executor, + false, + false + ); + + DriverContext driverContext = new DriverContext( + pipelineContext, + executor, + false + ); + + OperatorContext operatorContext = driverContext.addOperatorContext( + 1, + new PlanNodeId("test"), + "operator type" + ); + + return operatorContext; + } + + private TestingOperatorContext() {} +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java new file mode 100644 index 0000000000000..115860fa28a1f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java @@ -0,0 +1,167 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slices; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; +import org.openjdk.jmh.runner.options.WarmupMode; + +import java.util.Optional; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static org.openjdk.jmh.annotations.Level.Invocation; + +@SuppressWarnings("MethodMayBeStatic") +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(2) +@Warmup(iterations = 10, time = 2, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 10, time = 2, timeUnit = TimeUnit.SECONDS) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkArrayAggregation +{ + private static final int ARRAY_SIZE = 10_000_000; + + @Benchmark + @OperationsPerInvocation(ARRAY_SIZE) + public void arrayAggregation(BenchmarkData data) + throws Throwable + { + data.getAccumulator().addInput(data.getPage()); + } + + @SuppressWarnings("FieldMayBeFinal") + @State(Scope.Thread) + public static class BenchmarkData + { + private String name = "array_agg"; + + @Param({"BIGINT", "VARCHAR", "DOUBLE", "BOOLEAN"}) + private String type = "BIGINT"; + + private Page page; + private Accumulator accumulator; + + @Setup(Invocation) + public void setup() + { + MetadataManager metadata = MetadataManager.createTestMetadataManager(); + Block block; + Type elementType; + switch (type) { + case "BIGINT": + elementType = BIGINT; + break; + case "VARCHAR": + elementType = VARCHAR; + break; + case "DOUBLE": + elementType = DOUBLE; + break; + case "BOOLEAN": + elementType = BOOLEAN; + break; + default: + throw new UnsupportedOperationException(); + } + ArrayType arrayType = new ArrayType(elementType); + Signature signature = new Signature(name, AGGREGATE, arrayType.getTypeSignature(), elementType.getTypeSignature()); + InternalAggregationFunction function = metadata.getFunctionRegistry().getAggregateFunctionImplementation(signature); + accumulator = function.bind(ImmutableList.of(0), Optional.empty()).createAccumulator(); + + block = createChannel(ARRAY_SIZE, elementType); + page = new Page(block); + } + + private static Block createChannel(int arraySize, Type elementType) + { + BlockBuilder blockBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), arraySize); + for (int i = 0; i < arraySize; i++) { + if (elementType.getJavaType() == long.class) { + elementType.writeLong(blockBuilder, (long) i); + } + else if (elementType.getJavaType() == double.class) { + elementType.writeDouble(blockBuilder, ThreadLocalRandom.current().nextDouble()); + } + else if (elementType.getJavaType() == boolean.class) { + elementType.writeBoolean(blockBuilder, ThreadLocalRandom.current().nextBoolean()); + } + else if (elementType.equals(VARCHAR)) { + // make sure the size of a varchar is rather small; otherwise the aggregated slice may overflow + elementType.writeSlice(blockBuilder, Slices.utf8Slice(Long.toString(ThreadLocalRandom.current().nextLong() % 100))); + } + else { + throw new UnsupportedOperationException(); + } + } + return blockBuilder.build(); + } + + public Accumulator getAccumulator() + { + return accumulator; + } + + public Page getPage() + { + return page; + } + } + + public static void main(String[] args) + throws Throwable + { + // assure the benchmarks are valid before running + BenchmarkData data = new BenchmarkData(); + data.setup(); + new BenchmarkArrayAggregation().arrayAggregation(data); + + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .warmupMode(WarmupMode.BULK) + .include(".*" + BenchmarkArrayAggregation.class.getSimpleName() + ".*") + .build(); + new Runner(options).run(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java index 1c3b0cd3fcaf5..93ba6b0ea5c56 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java @@ -18,7 +18,7 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.RunLengthEncodedBlock; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java index 8d76cd96ab4ab..e3fc1315b7f7a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMaxNAggregation.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java index 7defcc9502ef2..cbc72565508c0 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayMinAggregation.java @@ -16,7 +16,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import java.util.List; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java index 9f10c3c552307..0b24dbf988aaf 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java @@ -16,6 +16,7 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DecimalType; @@ -24,7 +25,6 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarbinaryType; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.ArrayType; import org.testng.annotations.Test; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java index 1d013925d2a31..7d5c1ed6be992 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java @@ -20,9 +20,9 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; @@ -38,6 +38,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -126,7 +127,7 @@ public void testBadNumberOfBuckets() private static Map extractSingleValue(Block block) throws IOException { - MapType mapType = new MapType(DOUBLE, DOUBLE); + MapType mapType = mapType(DOUBLE, DOUBLE); return (Map) mapType.getObjectValue(null, block, 0); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java index 316ba88590ae2..9a905964b917d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java @@ -17,12 +17,12 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlTimestampWithTimeZone; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.TimeZoneKey; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.joda.time.DateTime; @@ -50,6 +50,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestHistogram { @@ -61,7 +62,7 @@ public class TestHistogram public void testSimpleHistograms() throws Exception { - MapType mapType = new MapType(VARCHAR, BIGINT); + MapType mapType = mapType(VARCHAR, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -72,7 +73,7 @@ public void testSimpleHistograms() ImmutableMap.of("a", 1L, "b", 1L, "c", 1L), createStringsBlock("a", "b", "c")); - mapType = new MapType(BIGINT, BIGINT); + mapType = mapType(BIGINT, BIGINT); aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -83,7 +84,7 @@ public void testSimpleHistograms() ImmutableMap.of(100L, 1L, 200L, 1L, 300L, 1L), createLongsBlock(100L, 200L, 300L)); - mapType = new MapType(DOUBLE, BIGINT); + mapType = mapType(DOUBLE, BIGINT); aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -94,7 +95,7 @@ public void testSimpleHistograms() ImmutableMap.of(0.1, 1L, 0.3, 1L, 0.2, 1L), createDoublesBlock(0.1, 0.3, 0.2)); - mapType = new MapType(BOOLEAN, BIGINT); + mapType = mapType(BOOLEAN, BIGINT); aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -110,7 +111,7 @@ public void testSimpleHistograms() public void testDuplicateKeysValues() throws Exception { - MapType mapType = new MapType(VARCHAR, BIGINT); + MapType mapType = mapType(VARCHAR, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -121,7 +122,7 @@ public void testDuplicateKeysValues() ImmutableMap.of("a", 2L, "b", 1L), createStringsBlock("a", "b", "a")); - mapType = new MapType(TIMESTAMP_WITH_TIME_ZONE, BIGINT); + mapType = mapType(TIMESTAMP_WITH_TIME_ZONE, BIGINT); aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -139,7 +140,7 @@ public void testDuplicateKeysValues() public void testWithNulls() throws Exception { - MapType mapType = new MapType(BIGINT, BIGINT); + MapType mapType = mapType(BIGINT, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -150,7 +151,7 @@ public void testWithNulls() ImmutableMap.of(1L, 1L, 2L, 1L), createLongsBlock(2L, null, 1L)); - mapType = new MapType(BIGINT, BIGINT); + mapType = mapType(BIGINT, BIGINT); aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -167,7 +168,7 @@ public void testArrayHistograms() throws Exception { ArrayType arrayType = new ArrayType(VARCHAR); - MapType mapType = new MapType(arrayType, BIGINT); + MapType mapType = mapType(arrayType, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -184,8 +185,8 @@ public void testArrayHistograms() public void testMapHistograms() throws Exception { - MapType innerMapType = new MapType(VARCHAR, VARCHAR); - MapType mapType = new MapType(innerMapType, BIGINT); + MapType innerMapType = mapType(VARCHAR, VARCHAR); + MapType mapType = mapType(innerMapType, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -208,7 +209,7 @@ public void testRowHistograms() throws Exception { RowType innerRowType = new RowType(ImmutableList.of(BIGINT, DOUBLE), Optional.of(ImmutableList.of("f1", "f2"))); - MapType mapType = new MapType(innerRowType, BIGINT); + MapType mapType = mapType(innerRowType, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -230,7 +231,7 @@ public void testRowHistograms() public void testLargerHistograms() throws Exception { - MapType mapType = new MapType(VARCHAR, BIGINT); + MapType mapType = mapType(VARCHAR, BIGINT); InternalAggregationFunction aggregationFunction = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java index 841747531d40f..e8fe4268fbb95 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java @@ -17,10 +17,10 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -44,6 +44,7 @@ import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestMapAggAggregation { @@ -53,7 +54,7 @@ public class TestMapAggAggregation public void testDuplicateKeysValues() throws Exception { - MapType mapType = new MapType(DOUBLE, VARCHAR); + MapType mapType = mapType(DOUBLE, VARCHAR); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -66,7 +67,7 @@ public void testDuplicateKeysValues() createDoublesBlock(1.0, 1.0, 1.0), createStringsBlock("a", "b", "c")); - mapType = new MapType(DOUBLE, INTEGER); + mapType = mapType(DOUBLE, INTEGER); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -84,7 +85,7 @@ public void testDuplicateKeysValues() public void testSimpleMaps() throws Exception { - MapType mapType = new MapType(DOUBLE, VARCHAR); + MapType mapType = mapType(DOUBLE, VARCHAR); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -97,7 +98,7 @@ public void testSimpleMaps() createDoublesBlock(1.0, 2.0, 3.0), createStringsBlock("a", "b", "c")); - mapType = new MapType(DOUBLE, INTEGER); + mapType = mapType(DOUBLE, INTEGER); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -110,7 +111,7 @@ public void testSimpleMaps() createDoublesBlock(1.0, 2.0, 3.0), createTypedLongsBlock(INTEGER, ImmutableList.of(3L, 2L, 1L))); - mapType = new MapType(DOUBLE, BOOLEAN); + mapType = mapType(DOUBLE, BOOLEAN); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, @@ -131,7 +132,7 @@ public void testNull() InternalAggregationFunction doubleDouble = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, - new MapType(DOUBLE, DOUBLE).getTypeSignature(), + mapType(DOUBLE, DOUBLE).getTypeSignature(), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); assertAggregation( @@ -162,7 +163,7 @@ public void testDoubleArrayMap() throws Exception { ArrayType arrayType = new ArrayType(VARCHAR); - MapType mapType = new MapType(DOUBLE, arrayType); + MapType mapType = mapType(DOUBLE, arrayType); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), @@ -182,8 +183,8 @@ public void testDoubleArrayMap() public void testDoubleMapMap() throws Exception { - MapType innerMapType = new MapType(VARCHAR, VARCHAR); - MapType mapType = new MapType(DOUBLE, innerMapType); + MapType innerMapType = mapType(VARCHAR, VARCHAR); + MapType mapType = mapType(DOUBLE, innerMapType); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), @@ -209,7 +210,7 @@ public void testDoubleRowMap() throws Exception { RowType innerRowType = new RowType(ImmutableList.of(INTEGER, DOUBLE), Optional.of(ImmutableList.of("f1", "f2"))); - MapType mapType = new MapType(DOUBLE, innerRowType); + MapType mapType = mapType(DOUBLE, innerRowType); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), @@ -235,7 +236,7 @@ public void testArrayDoubleMap() throws Exception { ArrayType arrayType = new ArrayType(VARCHAR); - MapType mapType = new MapType(arrayType, DOUBLE); + MapType mapType = mapType(arrayType, DOUBLE); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature( NAME, AGGREGATE, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java index 3cfdc804735af..5423fa84725c1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java @@ -15,8 +15,8 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Signature; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -33,6 +33,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.arrayBlockOf; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Preconditions.checkArgument; public class TestMapUnionAggregation @@ -43,7 +44,7 @@ public class TestMapUnionAggregation public void testSimpleWithDuplicates() throws Exception { - MapType mapType = new MapType(DOUBLE, VARCHAR); + MapType mapType = mapType(DOUBLE, VARCHAR); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); assertAggregation( @@ -54,7 +55,7 @@ public void testSimpleWithDuplicates() mapBlockOf(DOUBLE, VARCHAR, ImmutableMap.of(23.0, "aaa", 33.0, "bbb", 53.0, "ddd")), mapBlockOf(DOUBLE, VARCHAR, ImmutableMap.of(43.0, "ccc", 53.0, "ddd", 13.0, "eee")))); - mapType = new MapType(DOUBLE, BIGINT); + mapType = mapType(DOUBLE, BIGINT); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); assertAggregation( @@ -65,7 +66,7 @@ public void testSimpleWithDuplicates() mapBlockOf(DOUBLE, BIGINT, ImmutableMap.of(1.0, 99L, 2.0, 99L, 3.0, 99L)), mapBlockOf(DOUBLE, BIGINT, ImmutableMap.of(1.0, 44L, 2.0, 44L, 4.0, 44L)))); - mapType = new MapType(BOOLEAN, BIGINT); + mapType = mapType(BOOLEAN, BIGINT); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); assertAggregation( @@ -81,7 +82,7 @@ public void testSimpleWithDuplicates() public void testSimpleWithNulls() throws Exception { - MapType mapType = new MapType(DOUBLE, VARCHAR); + MapType mapType = mapType(DOUBLE, VARCHAR); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); @@ -101,7 +102,7 @@ public void testSimpleWithNulls() public void testStructural() throws Exception { - MapType mapType = new MapType(DOUBLE, new ArrayType(VARCHAR)); + MapType mapType = mapType(DOUBLE, new ArrayType(VARCHAR)); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); assertAggregation( @@ -134,7 +135,7 @@ public void testStructural() 3.0, ImmutableList.of("w", "z"))))); - mapType = new MapType(DOUBLE, new MapType(VARCHAR, VARCHAR)); + mapType = mapType(DOUBLE, mapType(VARCHAR, VARCHAR)); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); assertAggregation( @@ -147,7 +148,7 @@ public void testStructural() mapType, mapBlockOf( DOUBLE, - new MapType(VARCHAR, VARCHAR), + mapType(VARCHAR, VARCHAR), ImmutableMap.of( 1.0, ImmutableMap.of("a", "b"), @@ -155,12 +156,12 @@ public void testStructural() ImmutableMap.of("c", "d"))), mapBlockOf( DOUBLE, - new MapType(VARCHAR, VARCHAR), + mapType(VARCHAR, VARCHAR), ImmutableMap.of( 3.0, ImmutableMap.of("e", "f"))))); - mapType = new MapType(new ArrayType(VARCHAR), DOUBLE); + mapType = mapType(new ArrayType(VARCHAR), DOUBLE); aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); assertAggregation( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java index 7603ecda45de2..0a2595aa90209 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMinMaxByAggregation.java @@ -17,10 +17,10 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.state.StateCompiler; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java index afcfebbed5e6f..1a2cc3284f301 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java @@ -16,10 +16,10 @@ import com.facebook.presto.RowPageBuilder; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Signature; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -37,6 +37,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Preconditions.checkState; public class TestMultimapAggAggregation @@ -87,7 +88,7 @@ public void testNullMap() public void testDoubleMapMultimap() throws Exception { - Type mapType = new MapType(VARCHAR, BIGINT); + Type mapType = mapType(VARCHAR, BIGINT); List expectedKeys = ImmutableList.of(1.0, 2.0, 3.0); List> expectedValues = ImmutableList.of(ImmutableMap.of("a", 1L), ImmutableMap.of("b", 2L, "c", 3L, "d", 4L), ImmutableMap.of("a", 1L)); @@ -117,7 +118,7 @@ private static void testMultimapAgg(Type keyType, List expectedKeys, T { checkState(expectedKeys.size() == expectedValues.size(), "expectedKeys and expectedValues should have equal size"); - MapType mapType = new MapType(keyType, new ArrayType(valueType)); + MapType mapType = mapType(keyType, new ArrayType(valueType)); Signature signature = new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), keyType.getTypeSignature(), valueType.getTypeSignature()); InternalAggregationFunction aggFunc = metadata.getFunctionRegistry().getAggregateFunctionImplementation(signature); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java index f234fc00c4a06..9429b256958ff 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java @@ -20,9 +20,9 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; @@ -39,6 +39,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -127,7 +128,7 @@ public void testBadNumberOfBuckets() private static Map extractSingleValue(Block block) throws IOException { - MapType mapType = new MapType(REAL, REAL); + MapType mapType = mapType(REAL, REAL); return (Map) mapType.getObjectValue(null, block, 0); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java index 37963b9778c0d..0ace09355d23d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestStateCompiler.java @@ -13,6 +13,13 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.array.BlockBigArray; +import com.facebook.presto.array.BooleanBigArray; +import com.facebook.presto.array.ByteBigArray; +import com.facebook.presto.array.DoubleBigArray; +import com.facebook.presto.array.LongBigArray; +import com.facebook.presto.array.ReferenceCountMap; +import com.facebook.presto.array.SliceBigArray; import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.operator.aggregation.state.LongState; import com.facebook.presto.operator.aggregation.state.NullableLongState; @@ -26,16 +33,18 @@ import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateSerializer; import com.facebook.presto.spi.function.GroupedAccumulatorState; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.util.Reflection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; +import java.lang.invoke.MethodHandle; +import java.lang.reflect.Field; import java.util.Map; import java.util.Optional; @@ -46,6 +55,7 @@ import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedDoubleArray; import static org.testng.Assert.assertEquals; @@ -193,7 +203,7 @@ public void testVarianceStateSerialization() public void testComplexSerialization() { Type arrayType = new ArrayType(BIGINT); - Type mapType = new MapType(BIGINT, VARCHAR); + Type mapType = mapType(BIGINT, VARCHAR); Map fieldMap = ImmutableMap.of("Block", arrayType, "AnotherBlock", mapType); AccumulatorStateFactory factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldMap, new DynamicClassLoader(TestComplexState.class.getClassLoader())); AccumulatorStateSerializer serializer = StateCompiler.generateStateSerializer(TestComplexState.class, fieldMap, new DynamicClassLoader(TestComplexState.class.getClassLoader())); @@ -241,14 +251,68 @@ private long getSize(Slice slice) return slice.length() + SLICE_INSTANCE_SIZE; } + private long getComplexStateRetainedSize(TestComplexState state) + { + long retainedSize = ClassLayout.parseClass(state.getClass()).instanceSize(); + // reflection is necessary because TestComplexState implementation is generated + Field[] fields = state.getClass().getDeclaredFields(); + try { + for (Field field : fields) { + Class type = field.getType(); + field.setAccessible(true); + if (type == BlockBigArray.class || type == BooleanBigArray.class || type == SliceBigArray.class || + type == ByteBigArray.class || type == DoubleBigArray.class || type == LongBigArray.class) { + MethodHandle sizeOf = Reflection.methodHandle(type, "sizeOf", null); + retainedSize += (long) sizeOf.invokeWithArguments(field.get(state)); + } + } + } + catch (Throwable t) { + throw new RuntimeException(t); + } + return retainedSize; + } + + private static long getBlockBigArrayReferenceCountMapOverhead(TestComplexState state) + { + long overhead = 0; + // reflection is necessary because TestComplexState implementation is generated + Field[] stateFields = state.getClass().getDeclaredFields(); + try { + for (Field stateField : stateFields) { + if (stateField.getType() != BlockBigArray.class) { + continue; + } + stateField.setAccessible(true); + Field[] blockBigArrayFields = stateField.getType().getDeclaredFields(); + for (Field blockBigArrayField : blockBigArrayFields) { + if (blockBigArrayField.getType() != ReferenceCountMap.class) { + continue; + } + blockBigArrayField.setAccessible(true); + MethodHandle sizeOf = Reflection.methodHandle(blockBigArrayField.getType(), "sizeOf", null); + overhead += (long) sizeOf.invokeWithArguments(blockBigArrayField.get(stateField.get(state))); + } + } + } + catch (Throwable t) { + throw new RuntimeException(t); + } + return overhead; + } + @Test public void testComplexStateEstimatedSize() { - Map fieldMap = ImmutableMap.of("Block", new ArrayType(BIGINT), "AnotherBlock", new MapType(BIGINT, VARCHAR)); + Map fieldMap = ImmutableMap.of("Block", new ArrayType(BIGINT), "AnotherBlock", mapType(BIGINT, VARCHAR)); AccumulatorStateFactory factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldMap, new DynamicClassLoader(TestComplexState.class.getClassLoader())); TestComplexState groupedState = factory.createGroupedState(); - assertEquals(groupedState.getEstimatedSize(), 76064); + long initialRetainedSize = getComplexStateRetainedSize(groupedState); + assertEquals(groupedState.getEstimatedSize(), initialRetainedSize); + // BlockBigArray has an internal map that can grow in size when getting more blocks + // need to handle the map overhead separately + initialRetainedSize -= getBlockBigArrayReferenceCountMapOverhead(groupedState); for (int i = 0; i < 1000; i++) { long retainedSize = 0; ((GroupedAccumulatorState) groupedState).setGroupId(i); @@ -272,7 +336,7 @@ public void testComplexStateEstimatedSize() Block map = mapBlockBuilder.build(); retainedSize += map.getRetainedSizeInBytes(); groupedState.setAnotherBlock(map); - assertEquals(groupedState.getEstimatedSize(), 76064 + retainedSize * (i + 1)); + assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * (i + 1) + getBlockBigArrayReferenceCountMapOverhead(groupedState)); } for (int i = 0; i < 1000; i++) { @@ -298,7 +362,7 @@ public void testComplexStateEstimatedSize() Block map = mapBlockBuilder.build(); retainedSize += map.getRetainedSizeInBytes(); groupedState.setAnotherBlock(map); - assertEquals(groupedState.getEstimatedSize(), 76064 + retainedSize * 1000); + assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * 1000 + getBlockBigArrayReferenceCountMapOverhead(groupedState)); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java index 473943585e83b..ad98cfdee0a17 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedHistogram.java @@ -16,12 +16,14 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.MapType; import org.testng.annotations.Test; import java.util.function.IntUnaryOperator; import java.util.stream.IntStream; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; public class TestTypedHistogram @@ -42,7 +44,10 @@ public void testMassive() typedHistogram.add(i, inputBlock, 1); } - Block outputBlock = typedHistogram.serialize(); + MapType mapType = mapType(BIGINT, BIGINT); + BlockBuilder out = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); + typedHistogram.serialize(out); + Block outputBlock = mapType.getObject(out, 0); for (int i = 0; i < outputBlock.getPositionCount(); i += 2) { assertEquals(BIGINT.getLong(outputBlock, i + 1), BIGINT.getLong(outputBlock, i)); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java b/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java index 8c00c58fc55dd..1e30485be3597 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/index/TestFieldSetFilteringRecordSet.java @@ -17,9 +17,9 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.InMemoryRecordSet; import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java index ae4a37e4ec3c3..7c2ae8ddb32ba 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageFilter.java @@ -32,6 +32,7 @@ import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; public class TestDictionaryAwarePageFilter @@ -70,6 +71,15 @@ private static void testRleBlock(boolean filterRange) testFilter(filter, noMatch, filterRange); } + @Test + public void testRleBlockWithFailure() + throws Exception + { + DictionaryAwarePageFilter filter = createDictionaryAwarePageFilter(true, LongArrayBlock.class); + RunLengthEncodedBlock fail = new RunLengthEncodedBlock(createLongSequenceBlock(-10, -9), 100); + assertThrows(NegativeValueException.class, () -> testFilter(filter, fail, true)); + } + @Test public void testDictionaryBlock() throws Exception @@ -81,7 +91,14 @@ public void testDictionaryBlock() testFilter(createDictionaryBlock(20, 0), LongArrayBlock.class); // match all - testFilter(new DictionaryBlock(100, createLongSequenceBlock(4, 5), new int[100]), LongArrayBlock.class); + testFilter(new DictionaryBlock(createLongSequenceBlock(4, 5), new int[100]), LongArrayBlock.class); + } + + @Test + public void testDictionaryBlockWithFailure() + throws Exception + { + assertThrows(NegativeValueException.class, () -> testFilter(createDictionaryBlockWithFailure(20, 100), LongArrayBlock.class)); } @Test @@ -95,7 +112,7 @@ public void testDictionaryBlockProcessingWithUnusedFailure() testFilter(createDictionaryBlockWithUnusedEntries(20, 0), DictionaryBlock.class); // match all - testFilter(new DictionaryBlock(100, createLongsBlock(4, 5, -1), new int[100]), DictionaryBlock.class); + testFilter(new DictionaryBlock(createLongsBlock(4, 5, -1), new int[100]), DictionaryBlock.class); } @Test @@ -130,7 +147,15 @@ private static DictionaryBlock createDictionaryBlock(int dictionarySize, int blo Block dictionary = createLongSequenceBlock(0, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> index % dictionarySize); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); + } + + private static DictionaryBlock createDictionaryBlockWithFailure(int dictionarySize, int blockSize) + { + Block dictionary = createLongSequenceBlock(-10, dictionarySize - 10); + int[] ids = new int[blockSize]; + Arrays.setAll(ids, index -> index % dictionarySize); + return new DictionaryBlock(dictionary, ids); } private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictionarySize, int blockSize) @@ -138,7 +163,7 @@ private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictio Block dictionary = createLongSequenceBlock(-10, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> (index % dictionarySize) + 10); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); } private static void testFilter(Block block, Class expectedType) @@ -257,6 +282,7 @@ public SelectedPositions filter(ConnectorSession session, Page page) IntArrayList selectedPositions = new IntArrayList(); for (int position = 0; position < block.getPositionCount(); position++) { long value = block.getLong(position, 0); + verifyPositive(value); boolean selected = isSelected(filterRange, value); if (selected) { @@ -286,5 +312,22 @@ public SelectedPositions filter(ConnectorSession session, Page page) return SelectedPositions.positionsList(selectedPositions.elements(), 3, selectedPositions.size() - 6); } + + private static long verifyPositive(long value) + { + if (value < 0) { + throw new NegativeValueException(value); + } + return value; + } + } + + private static class NegativeValueException + extends RuntimeException + { + public NegativeValueException(long value) + { + super("value is negative: " + value); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java index 66342b525c7c4..7d20c3f82ca0d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/project/TestDictionaryAwarePageProjection.java @@ -35,6 +35,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static io.airlift.testing.Assertions.assertInstanceOf; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; public class TestDictionaryAwarePageProjection { @@ -66,6 +67,16 @@ public void testRleBlock() testProject(block, RunLengthEncodedBlock.class); } + @Test + public void testRleBlockWithFailure() + throws Exception + { + Block value = createLongSequenceBlock(-43, -42); + RunLengthEncodedBlock block = new RunLengthEncodedBlock(value, 100); + + testProjectFails(block, RunLengthEncodedBlock.class); + } + @Test public void testDictionaryBlock() throws Exception @@ -75,6 +86,15 @@ public void testDictionaryBlock() testProject(block, DictionaryBlock.class); } + @Test + public void testDictionaryBlockWithFailure() + throws Exception + { + DictionaryBlock block = createDictionaryBlockWithFailure(10, 100); + + testProjectFails(block, DictionaryBlock.class); + } + @Test public void testDictionaryBlockProcessingWithUnusedFailure() throws Exception @@ -115,7 +135,15 @@ private static DictionaryBlock createDictionaryBlock(int dictionarySize, int blo Block dictionary = createLongSequenceBlock(0, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> index % dictionarySize); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); + } + + private static DictionaryBlock createDictionaryBlockWithFailure(int dictionarySize, int blockSize) + { + Block dictionary = createLongSequenceBlock(-10, dictionarySize - 10); + int[] ids = new int[blockSize]; + Arrays.setAll(ids, index -> index % dictionarySize); + return new DictionaryBlock(dictionary, ids); } private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictionarySize, int blockSize) @@ -123,7 +151,7 @@ private static DictionaryBlock createDictionaryBlockWithUnusedEntries(int dictio Block dictionary = createLongSequenceBlock(-10, dictionarySize); int[] ids = new int[blockSize]; Arrays.setAll(ids, index -> (index % dictionarySize) + 10); - return new DictionaryBlock(blockSize, dictionary, ids); + return new DictionaryBlock(dictionary, ids); } private static void testProject(Block block, Class expectedResultType) @@ -134,6 +162,14 @@ private static void testProject(Block block, Class expectedResu testProjectList(lazyWrapper(block), expectedResultType, createProjection()); } + private static void testProjectFails(Block block, Class expectedResultType) + { + assertThrows(NegativeValueException.class, () -> testProjectRange(block, expectedResultType, createProjection())); + assertThrows(NegativeValueException.class, () -> testProjectList(block, expectedResultType, createProjection())); + assertThrows(NegativeValueException.class, () -> testProjectRange(lazyWrapper(block), expectedResultType, createProjection())); + assertThrows(NegativeValueException.class, () -> testProjectList(lazyWrapper(block), expectedResultType, createProjection())); + } + private static void testProjectRange(Block block, Class expectedResultType, DictionaryAwarePageProjection projection) { Block result = projection.project(null, new Page(block), SelectedPositions.positionsRange(5, 10)); @@ -212,9 +248,18 @@ public Block project(ConnectorSession session, Page page, SelectedPositions sele private static long verifyPositive(long value) { if (value < 0) { - throw new IllegalArgumentException("value is negative: " + value); + throw new NegativeValueException(value); } return value; } } + + private static class NegativeValueException + extends RuntimeException + { + public NegativeValueException(long value) + { + super("value is negative: " + value); + } + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java index 01a78e2dfe57f..223986fee4c80 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayDistinct.java @@ -24,11 +24,11 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java index ea46c75a949ad..e2d50932da7ec 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayFilter.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.gen.ExpressionCompiler; @@ -32,7 +33,6 @@ import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Throwables; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java index 5a0bb6fb1adff..0e8948c96480d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayHashCodeOperator.java @@ -28,12 +28,12 @@ import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; @@ -62,13 +62,13 @@ import static com.facebook.presto.operator.scalar.CombineHashFunction.getHash; import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; +import static com.facebook.presto.spi.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; -import static com.facebook.presto.type.ArrayType.ARRAY_NULL_ELEMENT_MSG; import static com.facebook.presto.type.TypeUtils.checkElementNotNull; import static com.facebook.presto.type.TypeUtils.hashPosition; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java index be08057dfc49c..227884d58385d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayJoin.java @@ -21,10 +21,10 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java index 52d0512b3f821..97540fbab4453 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySort.java @@ -23,11 +23,11 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java index 0a7b2e96e8c74..4bcaf97369448 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArraySubscript.java @@ -24,11 +24,11 @@ import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.openjdk.jmh.annotations.Benchmark; @@ -203,7 +203,7 @@ private static Block createDictionaryValueBlock(int positionCount, int mapSize) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = ThreadLocalRandom.current().nextInt(0, dictionarySize); } - return new DictionaryBlock(positionCount * mapSize, dictionaryBlock, keyIds); + return new DictionaryBlock(dictionaryBlock, keyIds); } private static String randomString(int length) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java index ea6f4bc81a9c6..d45baad21c907 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkArrayTransform.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; @@ -30,7 +31,6 @@ import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import org.openjdk.jmh.annotations.Benchmark; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java index c59b1f6a8daba..15de523800bdf 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapConcat.java @@ -25,10 +25,10 @@ import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.InterleavedBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.openjdk.jmh.annotations.Benchmark; @@ -58,6 +58,7 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static io.airlift.slice.Slices.utf8Slice; @SuppressWarnings("MethodMayBeStatic") @@ -120,7 +121,7 @@ public void setup() throw new UnsupportedOperationException(); } - MapType mapType = new MapType(createUnboundedVarcharType(), DOUBLE); + MapType mapType = mapType(createUnboundedVarcharType(), DOUBLE); Block leftKeyBlock = createKeyBlock(POSITIONS, leftKeys); Block leftValueBlock = createValueBlock(POSITIONS, leftKeys.size()); @@ -160,7 +161,7 @@ public Page getPage() private static Block createMapBlock(int positionCount, Block keyBlock, Block valueBlock) { - InterleavedBlock interleavedBlock = new InterleavedBlock(new Block[]{keyBlock, valueBlock}); + InterleavedBlock interleavedBlock = new InterleavedBlock(new Block[] {keyBlock, valueBlock}); int[] offsets = new int[positionCount + 1]; int mapSize = keyBlock.getPositionCount() / positionCount; for (int i = 0; i < offsets.length; i++) { @@ -176,7 +177,7 @@ private static Block createKeyBlock(int positionCount, List keys) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = i % keys.size(); } - return new DictionaryBlock(positionCount * keys.size(), keyDictionaryBlock, keyIds); + return new DictionaryBlock(keyDictionaryBlock, keyIds); } private static Block createValueBlock(int positionCount, int mapSize) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java index 66da9022ad8ce..63a128621fb83 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkMapSubscript.java @@ -25,11 +25,11 @@ import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.InterleavedBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.openjdk.jmh.annotations.Benchmark; @@ -62,6 +62,7 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Verify.verify; import static io.airlift.slice.Slices.utf8Slice; @@ -121,15 +122,15 @@ public void setup() Block valueBlock; switch (name) { case "fix-width": - mapType = new MapType(createUnboundedVarcharType(), DOUBLE); + mapType = mapType(createUnboundedVarcharType(), DOUBLE); valueBlock = createFixWidthValueBlock(POSITIONS, mapSize); break; case "var-width": - mapType = new MapType(createUnboundedVarcharType(), createUnboundedVarcharType()); + mapType = mapType(createUnboundedVarcharType(), createUnboundedVarcharType()); valueBlock = createVarWidthValueBlock(POSITIONS, mapSize); break; case "dictionary": - mapType = new MapType(createUnboundedVarcharType(), createUnboundedVarcharType()); + mapType = mapType(createUnboundedVarcharType(), createUnboundedVarcharType()); valueBlock = createDictionaryValueBlock(POSITIONS, mapSize); break; default: @@ -187,7 +188,7 @@ private static Block createKeyBlock(int positionCount, List keys) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = i % keys.size(); } - return new DictionaryBlock(positionCount * keys.size(), keyDictionaryBlock, keyIds); + return new DictionaryBlock(keyDictionaryBlock, keyIds); } private static Block createFixWidthValueBlock(int positionCount, int mapSize) @@ -226,7 +227,7 @@ private static Block createDictionaryValueBlock(int positionCount, int mapSize) for (int i = 0; i < keyIds.length; i++) { keyIds[i] = ThreadLocalRandom.current().nextInt(0, dictionarySize); } - return new DictionaryBlock(positionCount * mapSize, dictionaryBlock, keyIds); + return new DictionaryBlock(dictionaryBlock, keyIds); } private static String randomString(int length) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java index dfcd12b4fed82..7c126d87c2ea4 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformKey.java @@ -23,12 +23,12 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -61,6 +61,7 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.lang.String.format; @SuppressWarnings("MethodMayBeStatic") @@ -114,7 +115,7 @@ public void setup() default: throw new UnsupportedOperationException(); } - MapType mapType = new MapType(elementType, elementType); + MapType mapType = mapType(elementType, elementType); Signature signature = new Signature( name, FunctionKind.SCALAR, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java index b1cfbae3ec99f..cc4364570e800 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/BenchmarkTransformValue.java @@ -23,12 +23,12 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.VariableReferenceExpression; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; @@ -64,6 +64,7 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.lang.String.format; @SuppressWarnings("MethodMayBeStatic") @@ -121,8 +122,8 @@ public void setup() default: throw new UnsupportedOperationException(); } - MapType mapType = new MapType(elementType, elementType); - MapType returnType = new MapType(elementType, BOOLEAN); + MapType mapType = mapType(elementType, elementType); + MapType returnType = mapType(elementType, BOOLEAN); Signature signature = new Signature( name, FunctionKind.SCALAR, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/ConstructorWithInvalidTypeParameters.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/ConstructorWithInvalidTypeParameters.java new file mode 100644 index 0000000000000..dcc30ba7154ff --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/ConstructorWithInvalidTypeParameters.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; + +public final class ConstructorWithInvalidTypeParameters +{ + @TypeParameter("K") + @TypeParameter("V") + public ConstructorWithInvalidTypeParameters(@TypeParameter("K(varchar)") Type type) {} + + @ScalarFunction + @TypeParameter("K") + @TypeParameter("V") + @SqlType(StandardTypes.BIGINT) + public long good1( + @TypeParameter("MAP(K,V)") Type type, + @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/ConstructorWithValidTypeParameters.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/ConstructorWithValidTypeParameters.java new file mode 100644 index 0000000000000..f98c47ed75be0 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/ConstructorWithValidTypeParameters.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; + +public final class ConstructorWithValidTypeParameters +{ + @TypeParameter("K") + @TypeParameter("V") + public ConstructorWithValidTypeParameters(@TypeParameter("MAP(K,V)") Type type) {} + + @ScalarFunction + @TypeParameter("K") + @TypeParameter("V") + @SqlType(StandardTypes.BIGINT) + public long good1( + @TypeParameter("MAP(K,V)") Type type, + @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index fd74199ba579e..2da00738c010d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -59,11 +59,11 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.TestingTransactionHandle; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -327,7 +327,7 @@ private List executeProjectionWithAll(String projection, Type expectedTy private RowExpression toRowExpression(Expression projectionExpression) { Expression translatedProjection = new SymbolToInputRewriter(INPUT_MAPPING).rewrite(projectionExpression); - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput( + Map, Type> expressionTypes = getExpressionTypesFromInput( TEST_SESSION, metadata, SQL_PARSER, @@ -640,7 +640,7 @@ private static SourceOperatorFactory compileScanFilterProject(Optional expressionTypes) + private RowExpression toRowExpression(Expression projection, Map, Type> expressionTypes) { return translate(projection, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java index eb752d712fc5f..8fca03b64875d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java index b853532e9bda3..299c35cad7309 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFilterFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java index 9c3dd47bc1f47..86dd9cbc89499 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayReduceFunction.java @@ -13,13 +13,14 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; +import com.facebook.presto.spi.type.ArrayType; import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.TimeZoneKey.getTimeZoneKey; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.util.Arrays.asList; @@ -94,4 +95,12 @@ public void testInstanceFunction() new ArrayType(INTEGER), asList(1, 2, 3, 4, 5, null, 7)); } + + @Test + public void testCoercion() + { + assertFunction("reduce(ARRAY [123456789012345, NULL, 54321], 0, (s, x) -> s + coalesce(x, 0), s -> s)", BIGINT, 123456789066666L); + // TODO: Support coercion of return type of lambda + assertInvalidFunction("reduce(ARRAY [1, NULL, 2], 0, (s, x) -> CAST (s + x AS TINYINT), s -> s)", FUNCTION_NOT_FOUND); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java index 0b2d71f4a0dff..19801f9bc97d8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayTransformFunction.java @@ -13,9 +13,8 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -31,6 +30,7 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.util.Arrays.asList; public class TestArrayTransformFunction @@ -45,6 +45,7 @@ public TestArrayTransformFunction() public void testBasic() throws Exception { + assertFunction("transform(ARRAY [5, 6], x -> 9)", new ArrayType(INTEGER), ImmutableList.of(9, 9)); assertFunction("transform(ARRAY [5, 6], x -> x + 1)", new ArrayType(INTEGER), ImmutableList.of(6, 7)); assertFunction("transform(ARRAY [5 + RANDOM(1), 6], x -> x + 1)", new ArrayType(INTEGER), ImmutableList.of(6, 7)); } @@ -92,7 +93,7 @@ public void testTypeCombinations() assertFunction("transform(ARRAY [25.6, 27.3], x -> CAST(x AS VARCHAR))", new ArrayType(createUnboundedVarcharType()), ImmutableList.of("25.6", "27.3")); assertFunction( "transform(ARRAY [25.6, 27.3], x -> MAP(ARRAY[x + 1], ARRAY[true]))", - new ArrayType(new MapType(DOUBLE, BOOLEAN)), + new ArrayType(mapType(DOUBLE, BOOLEAN)), ImmutableList.of(ImmutableMap.of(26.6, true), ImmutableMap.of(28.3, true))); assertFunction("transform(ARRAY [true, false], x -> if(x, 25, 26))", new ArrayType(INTEGER), ImmutableList.of(25, 26)); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java index 5268845645aa0..fda545db4d8b2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.type.TimestampType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.testing.TestingConnectorSession; +import com.facebook.presto.type.SqlIntervalDayTime; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.joda.time.DateTime; @@ -49,6 +50,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; import static java.lang.Math.toIntExact; import static java.lang.String.format; @@ -911,6 +913,44 @@ public void testDateTimeOutputString() assertFunctionString("timestamp '2333-02-23 23:59:59.999 Asia/Tokyo'", TIMESTAMP_WITH_TIME_ZONE, "2333-02-23 23:59:59.999 Asia/Tokyo"); } + @Test + public void testParseDuration() + { + assertFunction("parse_duration('1234 ns')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 0)); + assertFunction("parse_duration('1234 us')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 1)); + assertFunction("parse_duration('1234 ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1, 234)); + assertFunction("parse_duration('1234 s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 20, 34, 0)); + assertFunction("parse_duration('1234 m')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 20, 34, 0, 0)); + assertFunction("parse_duration('1234 h')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(51, 10, 0, 0, 0)); + assertFunction("parse_duration('1234 d')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(1234, 0, 0, 0, 0)); + assertFunction("parse_duration('1234.567 ns')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 0)); + assertFunction("parse_duration('1234.567 ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1, 235)); + assertFunction("parse_duration('1234.567 s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1234, 567)); + assertFunction("parse_duration('1234.567 m')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 20, 34, 34, 20)); + assertFunction("parse_duration('1234.567 h')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(51, 10, 34, 1, 200)); + assertFunction("parse_duration('1234.567 d')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(1234, 13, 36, 28, 800)); + + // without space + assertFunction("parse_duration('1234ns')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 0)); + assertFunction("parse_duration('1234us')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 1)); + assertFunction("parse_duration('1234ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1, 234)); + assertFunction("parse_duration('1234s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 20, 34, 0)); + assertFunction("parse_duration('1234m')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 20, 34, 0, 0)); + assertFunction("parse_duration('1234h')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(51, 10, 0, 0, 0)); + assertFunction("parse_duration('1234d')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(1234, 0, 0, 0, 0)); + assertFunction("parse_duration('1234.567ns')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 0)); + assertFunction("parse_duration('1234.567ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1, 235)); + assertFunction("parse_duration('1234.567s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1234, 567)); + assertFunction("parse_duration('1234.567m')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 20, 34, 34, 20)); + assertFunction("parse_duration('1234.567h')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(51, 10, 34, 1, 200)); + assertFunction("parse_duration('1234.567d')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(1234, 13, 36, 28, 800)); + + // invalid function calls + assertInvalidFunction("parse_duration('')", "duration is empty"); + assertInvalidFunction("parse_duration('1f')", "Unknown time unit: f"); + assertInvalidFunction("parse_duration('abc')", "duration is not a valid data duration string: abc"); + } + private void assertFunctionString(String projection, Type expectedType, String expected) { functionAssertions.assertFunctionString(projection, expectedType, expected); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestGroupingOperationFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestGroupingOperationFunction.java new file mode 100644 index 0000000000000..5479cdcc3d407 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestGroupingOperationFunction.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.bigintGrouping; +import static com.facebook.presto.operator.scalar.GroupingOperationFunction.integerGrouping; +import static org.testng.Assert.assertEquals; + +public class TestGroupingOperationFunction +{ + private static final List fortyIntegers = ImmutableList.of( + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 + ); + + @Test + public void testGroupingOperationAllBitsSet() + { + List groupingOrdinals = ImmutableList.of(0, 4, 8); + List> groupingSetOrdinals = ImmutableList.of(ImmutableList.of(1), ImmutableList.of(7, 3, 1), ImmutableList.of(9, 1)); + + for (int groupId = 0; groupId < groupingSetOrdinals.size(); groupId++) { + assertEquals(bigintGrouping(groupId, groupingOrdinals, groupingSetOrdinals), 7L); + assertEquals(integerGrouping(groupId, groupingOrdinals, groupingSetOrdinals), 7L); + } + } + + @Test + public void testGroupingOperationNoBitsSet() + { + List groupingOrdinals = ImmutableList.of(4, 6); + List> groupingSetOrdinals = ImmutableList.of(ImmutableList.of(4, 6)); + + for (int groupId = 0; groupId < groupingSetOrdinals.size(); groupId++) { + assertEquals(bigintGrouping(groupId, groupingOrdinals, groupingSetOrdinals), 0L); + assertEquals(integerGrouping(groupId, groupingOrdinals, groupingSetOrdinals), 0L); + } + } + + @Test + public void testGroupingOperationSomeBitsSet() + { + List groupingOrdinals = ImmutableList.of(7, 2, 9, 3, 5); + List> groupingSetOrdinals = ImmutableList.of(ImmutableList.of(4, 2), ImmutableList.of(9, 7, 14), ImmutableList.of(5, 2, 7), ImmutableList.of(3)); + List expectedResults = ImmutableList.of(23L, 11L, 6L, 29L); + + for (int groupId = 0; groupId < groupingSetOrdinals.size(); groupId++) { + assertEquals(Long.valueOf(bigintGrouping(groupId, groupingOrdinals, groupingSetOrdinals)), expectedResults.get(groupId)); + assertEquals(Long.valueOf(integerGrouping(groupId, groupingOrdinals, groupingSetOrdinals)), expectedResults.get(groupId)); + } + } + + @Test + public void testMoreThanThirtyTwoArguments() + { + List> groupingSetOrdinals = ImmutableList.of(ImmutableList.of(20, 2, 13, 33, 40, 9, 14), ImmutableList.of(28, 4, 5, 29, 31, 10)); + List expectedResults = ImmutableList.of(822283861886L, 995358664191L); + + for (int groupId = 0; groupId < groupingSetOrdinals.size(); groupId++) { + assertEquals(Long.valueOf(bigintGrouping(groupId, fortyIntegers, groupingSetOrdinals)), expectedResults.get(groupId)); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java index d08e6f6500f55..0223ef04ee040 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestLambdaExpression.java @@ -14,9 +14,8 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.Session; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.BeforeClass; @@ -36,6 +35,7 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestLambdaExpression extends AbstractTestFunctions @@ -127,8 +127,8 @@ public void testBind() { assertFunction("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", INTEGER, 99); assertFunction("invoke(\"$internal$bind\"(8, x -> x + 1))", INTEGER, 9); - assertFunction("apply(900, \"$internal$bind\"(90, \"$internal$bind\"(9, (x, y, z) -> x + y + z)))", INTEGER, 999); - assertFunction("invoke(\"$internal$bind\"(90, \"$internal$bind\"(9, (x, y) -> x + y)))", INTEGER, 99); + assertFunction("apply(900, \"$internal$bind\"(90, 9, (x, y, z) -> x + y + z))", INTEGER, 999); + assertFunction("invoke(\"$internal$bind\"(90, 9, (x, y) -> x + y))", INTEGER, 99); } @Test @@ -155,7 +155,7 @@ public void testTypeCombinations() assertFunction("apply(25.6, x -> x + 1.0)", DOUBLE, 26.6); assertFunction("apply(25.6, x -> x = 25.6)", BOOLEAN, true); assertFunction("apply(25.6, x -> CAST(x AS VARCHAR))", createUnboundedVarcharType(), "25.6"); - assertFunction("apply(25.6, x -> MAP(ARRAY[x + 1], ARRAY[true]))", new MapType(DOUBLE, BOOLEAN), ImmutableMap.of(26.6, true)); + assertFunction("apply(25.6, x -> MAP(ARRAY[x + 1], ARRAY[true]))", mapType(DOUBLE, BOOLEAN), ImmutableMap.of(26.6, true)); assertFunction("apply(true, x -> if(x, 25, 26))", INTEGER, 25); assertFunction("apply(false, x -> if(x, 25.6, 28.9))", DOUBLE, 28.9); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestListLiteralCast.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestListLiteralCast.java new file mode 100644 index 0000000000000..d3a9a31bdaecf --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestListLiteralCast.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.operator.scalar.ListLiteralCast.castArrayOfArraysToListLiteral; +import static com.facebook.presto.operator.scalar.ListLiteralCast.castArrayToListLiteral; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static org.testng.Assert.assertEquals; + +public class TestListLiteralCast +{ + @Test + public void testArrayToListCast() + { + List expectedValues = ImmutableList.of(1, 2, 3, 4); + + assertEquals(expectedValues, castArrayToListLiteral(buildIntArrayBlockWithIntValues(expectedValues))); + } + + @Test + public void testArrayOfArraysToListLiteral() + { + List> expectedValues = ImmutableList.of( + ImmutableList.of(-1, 726, -44, 0), + ImmutableList.of(4), + ImmutableList.of(7911, 30076, 432, 111)); + + Type integerArrayType = new TypeRegistry().getType(parseTypeSignature("array(integer)")); + BlockBuilder blockBuilder = integerArrayType.createBlockBuilder(new BlockBuilderStatus(), 3); + for (List values : expectedValues) { + blockBuilder.writeObject(buildIntArrayBlockWithIntValues(values)).closeEntry(); + } + + assertEquals(expectedValues, castArrayOfArraysToListLiteral(blockBuilder.build())); + } + + private static Block buildIntArrayBlockWithIntValues(List values) + { + BlockBuilder blockBuilder = INTEGER.createBlockBuilder(new BlockBuilderStatus(), values.size()); + for (Integer value : values) { + INTEGER.writeLong(blockBuilder, value); + } + + return blockBuilder.build(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java index 20a46c1cfeb45..54c7f85e5e26b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapFilterFunction.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -29,6 +28,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestMapFilterFunction extends AbstractTestFunctions @@ -37,10 +37,10 @@ public class TestMapFilterFunction public void testEmpty() throws Exception { - assertFunction("map_filter(map(ARRAY[], ARRAY[]), (k, v) -> true)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("map_filter(map(ARRAY[], ARRAY[]), (k, v) -> false)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("map_filter(map(ARRAY[], ARRAY[]), (k, v) -> CAST (NULL AS BOOLEAN))", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("map_filter(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> true)", new MapType(BIGINT, VARCHAR), ImmutableMap.of()); + assertFunction("map_filter(map(ARRAY[], ARRAY[]), (k, v) -> true)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("map_filter(map(ARRAY[], ARRAY[]), (k, v) -> false)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("map_filter(map(ARRAY[], ARRAY[]), (k, v) -> CAST (NULL AS BOOLEAN))", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("map_filter(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> true)", mapType(BIGINT, VARCHAR), ImmutableMap.of()); } @Test @@ -49,15 +49,15 @@ public void testNull() { Map oneToNullMap = new HashMap<>(); oneToNullMap.put(1, null); - assertFunction("map_filter(map(ARRAY[1], ARRAY [NULL]), (k, v) -> v IS NULL)", new MapType(INTEGER, UNKNOWN), oneToNullMap); - assertFunction("map_filter(map(ARRAY[1], ARRAY [NULL]), (k, v) -> v IS NOT NULL)", new MapType(INTEGER, UNKNOWN), ImmutableMap.of()); - assertFunction("map_filter(map(ARRAY[1], ARRAY [CAST (NULL AS INTEGER)]), (k, v) -> v IS NULL)", new MapType(INTEGER, INTEGER), oneToNullMap); + assertFunction("map_filter(map(ARRAY[1], ARRAY [NULL]), (k, v) -> v IS NULL)", mapType(INTEGER, UNKNOWN), oneToNullMap); + assertFunction("map_filter(map(ARRAY[1], ARRAY [NULL]), (k, v) -> v IS NOT NULL)", mapType(INTEGER, UNKNOWN), ImmutableMap.of()); + assertFunction("map_filter(map(ARRAY[1], ARRAY [CAST (NULL AS INTEGER)]), (k, v) -> v IS NULL)", mapType(INTEGER, INTEGER), oneToNullMap); Map sequenceToNullMap = new HashMap<>(); sequenceToNullMap.put(1, null); sequenceToNullMap.put(2, null); sequenceToNullMap.put(3, null); - assertFunction("map_filter(map(ARRAY[1, 2, 3], ARRAY [NULL, NULL, NULL]), (k, v) -> v IS NULL)", new MapType(INTEGER, UNKNOWN), sequenceToNullMap); - assertFunction("map_filter(map(ARRAY[1, 2, 3], ARRAY [NULL, NULL, NULL]), (k, v) -> v IS NOT NULL)", new MapType(INTEGER, UNKNOWN), ImmutableMap.of()); + assertFunction("map_filter(map(ARRAY[1, 2, 3], ARRAY [NULL, NULL, NULL]), (k, v) -> v IS NULL)", mapType(INTEGER, UNKNOWN), sequenceToNullMap); + assertFunction("map_filter(map(ARRAY[1, 2, 3], ARRAY [NULL, NULL, NULL]), (k, v) -> v IS NOT NULL)", mapType(INTEGER, UNKNOWN), ImmutableMap.of()); } @Test @@ -66,19 +66,19 @@ public void testBasic() { assertFunction( "map_filter(map(ARRAY [5, 6, 7, 8], ARRAY [5, 6, 6, 5]), (x, y) -> x <= 6 OR y = 5)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(5, 5, 6, 6, 8, 5)); assertFunction( "map_filter(map(ARRAY [5 + RANDOM(1), 6, 7, 8], ARRAY [5, 6, 6, 5]), (x, y) -> x <= 6 OR y = 5)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(5, 5, 6, 6, 8, 5)); assertFunction( "map_filter(map(ARRAY ['a', 'b', 'c', 'd'], ARRAY [1, 2, NULL, 4]), (k, v) -> v IS NOT NULL)", - new MapType(createVarcharType(1), INTEGER), + mapType(createVarcharType(1), INTEGER), ImmutableMap.of("a", 1, "b", 2, "d", 4)); assertFunction( "map_filter(map(ARRAY ['a', 'b', 'c'], ARRAY [TRUE, FALSE, NULL]), (k, v) -> v)", - new MapType(createVarcharType(1), BOOLEAN), + mapType(createVarcharType(1), BOOLEAN), ImmutableMap.of("a", true)); } @@ -88,109 +88,109 @@ public void testTypeCombinations() { assertFunction( "map_filter(map(ARRAY [25, 26, 27], ARRAY [25, 26, 27]), (k, v) -> k = 25 OR v = 27)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(25, 25, 27, 27)); assertFunction( "map_filter(map(ARRAY [25, 26, 27], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k = 25 OR v = 27.5)", - new MapType(INTEGER, DOUBLE), + mapType(INTEGER, DOUBLE), ImmutableMap.of(25, 25.5, 27, 27.5)); assertFunction( "map_filter(map(ARRAY [25, 26, 27], ARRAY [false, null, true]), (k, v) -> k = 25 OR v)", - new MapType(INTEGER, BOOLEAN), + mapType(INTEGER, BOOLEAN), ImmutableMap.of(25, false, 27, true)); assertFunction( "map_filter(map(ARRAY [25, 26, 27], ARRAY ['abc', 'def', 'xyz']), (k, v) -> k = 25 OR v = 'xyz')", - new MapType(INTEGER, createVarcharType(3)), + mapType(INTEGER, createVarcharType(3)), ImmutableMap.of(25, "abc", 27, "xyz")); assertFunction( "map_filter(map(ARRAY [25, 26, 27], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'c'], ARRAY ['a', 'b', 'c']]), (k, v) -> k = 25 OR cardinality(v) = 3)", - new MapType(INTEGER, new ArrayType(createVarcharType(1))), + mapType(INTEGER, new ArrayType(createVarcharType(1))), ImmutableMap.of(25, ImmutableList.of("a", "b"), 27, ImmutableList.of("a", "b", "c"))); assertFunction( "map_filter(map(ARRAY [25.5, 26.5, 27.5], ARRAY [25, 26, 27]), (k, v) -> k = 25.5 OR v = 27)", - new MapType(DOUBLE, INTEGER), + mapType(DOUBLE, INTEGER), ImmutableMap.of(25.5, 25, 27.5, 27)); assertFunction( "map_filter(map(ARRAY [25.5, 26.5, 27.5], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k = 25.5 OR v = 27.5)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(25.5, 25.5, 27.5, 27.5)); assertFunction( "map_filter(map(ARRAY [25.5, 26.5, 27.5], ARRAY [false, null, true]), (k, v) -> k = 25.5 OR v)", - new MapType(DOUBLE, BOOLEAN), + mapType(DOUBLE, BOOLEAN), ImmutableMap.of(25.5, false, 27.5, true)); assertFunction( "map_filter(map(ARRAY [25.5, 26.5, 27.5], ARRAY ['abc', 'def', 'xyz']), (k, v) -> k = 25.5 OR v = 'xyz')", - new MapType(DOUBLE, createVarcharType(3)), + mapType(DOUBLE, createVarcharType(3)), ImmutableMap.of(25.5, "abc", 27.5, "xyz")); assertFunction( "map_filter(map(ARRAY [25.5, 26.5, 27.5], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'c'], ARRAY ['a', 'b', 'c']]), (k, v) -> k = 25.5 OR cardinality(v) = 3)", - new MapType(DOUBLE, new ArrayType(createVarcharType(1))), + mapType(DOUBLE, new ArrayType(createVarcharType(1))), ImmutableMap.of(25.5, ImmutableList.of("a", "b"), 27.5, ImmutableList.of("a", "b", "c"))); assertFunction( "map_filter(map(ARRAY [true, false], ARRAY [25, 26]), (k, v) -> k AND v = 25)", - new MapType(BOOLEAN, INTEGER), + mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 25)); assertFunction( "map_filter(map(ARRAY [false, true], ARRAY [25.5, 26.5]), (k, v) -> k OR v > 100)", - new MapType(BOOLEAN, DOUBLE), + mapType(BOOLEAN, DOUBLE), ImmutableMap.of(true, 26.5)); Map falseToNullMap = new HashMap<>(); falseToNullMap.put(false, null); assertFunction( "map_filter(map(ARRAY [true, false], ARRAY [false, null]), (k, v) -> NOT k OR v)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), falseToNullMap); assertFunction( "map_filter(map(ARRAY [false, true], ARRAY ['abc', 'def']), (k, v) -> NOT k AND v = 'abc')", - new MapType(BOOLEAN, createVarcharType(3)), + mapType(BOOLEAN, createVarcharType(3)), ImmutableMap.of(false, "abc")); assertFunction( "map_filter(map(ARRAY [true, false], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'b', 'c']]), (k, v) -> k OR cardinality(v) = 3)", - new MapType(BOOLEAN, new ArrayType(createVarcharType(1))), + mapType(BOOLEAN, new ArrayType(createVarcharType(1))), ImmutableMap.of(true, ImmutableList.of("a", "b"), false, ImmutableList.of("a", "b", "c"))); assertFunction( "map_filter(map(ARRAY ['s0', 's1', 's2'], ARRAY [25, 26, 27]), (k, v) -> k = 's0' OR v = 27)", - new MapType(createVarcharType(2), INTEGER), + mapType(createVarcharType(2), INTEGER), ImmutableMap.of("s0", 25, "s2", 27)); assertFunction( "map_filter(map(ARRAY ['s0', 's1', 's2'], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k = 's0' OR v = 27.5)", - new MapType(createVarcharType(2), DOUBLE), + mapType(createVarcharType(2), DOUBLE), ImmutableMap.of("s0", 25.5, "s2", 27.5)); assertFunction( "map_filter(map(ARRAY ['s0', 's1', 's2'], ARRAY [false, null, true]), (k, v) -> k = 's0' OR v)", - new MapType(createVarcharType(2), BOOLEAN), + mapType(createVarcharType(2), BOOLEAN), ImmutableMap.of("s0", false, "s2", true)); assertFunction( "map_filter(map(ARRAY ['s0', 's1', 's2'], ARRAY ['abc', 'def', 'xyz']), (k, v) -> k = 's0' OR v = 'xyz')", - new MapType(createVarcharType(2), createVarcharType(3)), + mapType(createVarcharType(2), createVarcharType(3)), ImmutableMap.of("s0", "abc", "s2", "xyz")); assertFunction( "map_filter(map(ARRAY ['s0', 's1', 's2'], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'c'], ARRAY ['a', 'b', 'c']]), (k, v) -> k = 's0' OR cardinality(v) = 3)", - new MapType(createVarcharType(2), new ArrayType(createVarcharType(1))), + mapType(createVarcharType(2), new ArrayType(createVarcharType(1))), ImmutableMap.of("s0", ImmutableList.of("a", "b"), "s2", ImmutableList.of("a", "b", "c"))); assertFunction( "map_filter(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4], ARRAY []], ARRAY [25, 26, 27]), (k, v) -> k = ARRAY [1, 2] OR v = 27)", - new MapType(new ArrayType(INTEGER), INTEGER), + mapType(new ArrayType(INTEGER), INTEGER), ImmutableMap.of(ImmutableList.of(1, 2), 25, ImmutableList.of(), 27)); assertFunction( "map_filter(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4], ARRAY []], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k = ARRAY [1, 2] OR v = 27.5)", - new MapType(new ArrayType(INTEGER), DOUBLE), + mapType(new ArrayType(INTEGER), DOUBLE), ImmutableMap.of(ImmutableList.of(1, 2), 25.5, ImmutableList.of(), 27.5)); assertFunction( "map_filter(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4], ARRAY []], ARRAY [false, null, true]), (k, v) -> k = ARRAY [1, 2] OR v)", - new MapType(new ArrayType(INTEGER), BOOLEAN), + mapType(new ArrayType(INTEGER), BOOLEAN), ImmutableMap.of(ImmutableList.of(1, 2), false, ImmutableList.of(), true)); assertFunction( "map_filter(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4], ARRAY []], ARRAY ['abc', 'def', 'xyz']), (k, v) -> k = ARRAY [1, 2] OR v = 'xyz')", - new MapType(new ArrayType(INTEGER), createVarcharType(3)), + mapType(new ArrayType(INTEGER), createVarcharType(3)), ImmutableMap.of(ImmutableList.of(1, 2), "abc", ImmutableList.of(), "xyz")); assertFunction( "map_filter(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4], ARRAY []], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'b', 'c'], ARRAY ['a', 'c']]), (k, v) -> cardinality(k) = 0 OR cardinality(v) = 3)", - new MapType(new ArrayType(INTEGER), new ArrayType(createVarcharType(1))), + mapType(new ArrayType(INTEGER), new ArrayType(createVarcharType(1))), ImmutableMap.of(ImmutableList.of(3, 4), ImmutableList.of("a", "b", "c"), ImmutableList.of(), ImmutableList.of("a", "c"))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java index de31487e1f97b..0cc0d621fe238 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformKeyFunction.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -29,6 +28,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestMapTransformKeyFunction extends AbstractTestFunctions @@ -37,15 +37,15 @@ public class TestMapTransformKeyFunction public void testEmpty() throws Exception { - assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> NULL)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> k)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> v)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - - assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> 0)", new MapType(INTEGER, UNKNOWN), ImmutableMap.of()); - assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> true)", new MapType(BOOLEAN, UNKNOWN), ImmutableMap.of()); - assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> 'key')", new MapType(createVarcharType(3), UNKNOWN), ImmutableMap.of()); - assertFunction("transform_keys(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> k + CAST(v as BIGINT))", new MapType(BIGINT, VARCHAR), ImmutableMap.of()); - assertFunction("transform_keys(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> v)", new MapType(VARCHAR, VARCHAR), ImmutableMap.of()); + assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> NULL)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> k)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> v)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + + assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> 0)", mapType(INTEGER, UNKNOWN), ImmutableMap.of()); + assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> true)", mapType(BOOLEAN, UNKNOWN), ImmutableMap.of()); + assertFunction("transform_keys(map(ARRAY[], ARRAY[]), (k, v) -> 'key')", mapType(createVarcharType(3), UNKNOWN), ImmutableMap.of()); + assertFunction("transform_keys(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> k + CAST(v as BIGINT))", mapType(BIGINT, VARCHAR), ImmutableMap.of()); + assertFunction("transform_keys(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> v)", mapType(VARCHAR, VARCHAR), ImmutableMap.of()); } @Test @@ -82,22 +82,22 @@ public void testBasic() { assertFunction( "transform_keys(map(ARRAY [1, 2, 3, 4], ARRAY [10, 20, 30, 40]), (k, v) -> k + v)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(11, 10, 22, 20, 33, 30, 44, 40)); assertFunction( "transform_keys(map(ARRAY ['a', 'b', 'c', 'd'], ARRAY [1, 2, 3, 4]), (k, v) -> v * v)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(1, 1, 4, 2, 9, 3, 16, 4)); assertFunction( "transform_keys(map(ARRAY ['a', 'b', 'c', 'd'], ARRAY [1, 2, 3, 4]), (k, v) -> k || CAST(v as VARCHAR))", - new MapType(VARCHAR, INTEGER), + mapType(VARCHAR, INTEGER), ImmutableMap.of("a1", 1, "b2", 2, "c3", 3, "d4", 4)); assertFunction( "transform_keys(map(ARRAY[1, 2, 3], ARRAY [1.0, 1.4, 1.7]), (k, v) -> map(ARRAY[1, 2, 3], ARRAY['one', 'two', 'three'])[k])", - new MapType(createVarcharType(5), DOUBLE), + mapType(createVarcharType(5), DOUBLE), ImmutableMap.of("one", 1.0, "two", 1.4, "three", 1.7)); Map expectedStringIntMap = new HashMap<>(); @@ -107,7 +107,7 @@ public void testBasic() expectedStringIntMap.put("d4", 4); assertFunction( "transform_keys(map(ARRAY ['a', 'b', 'c', 'd'], ARRAY [1, NULL, 3, 4]), (k, v) -> k || COALESCE(CAST(v as VARCHAR), '0'))", - new MapType(VARCHAR, INTEGER), + mapType(VARCHAR, INTEGER), expectedStringIntMap); } @@ -117,110 +117,110 @@ public void testTypeCombinations() { assertFunction( "transform_keys(map(ARRAY [25, 26, 27], ARRAY [25, 26, 27]), (k, v) -> k + v)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(50, 25, 52, 26, 54, 27)); assertFunction( "transform_keys(map(ARRAY [25, 26, 27], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k + v)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(50.5, 25.5, 52.5, 26.5, 54.5, 27.5)); assertFunction( "transform_keys(map(ARRAY [25, 26], ARRAY [false, true]), (k, v) -> k % 2 = 0 OR v)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, false, true, true)); assertFunction( "transform_keys(map(ARRAY [25, 26, 27], ARRAY ['abc', 'def', 'xyz']), (k, v) -> to_base(k, 16) || substr(v, 1, 1))", - new MapType(VARCHAR, createVarcharType(3)), + mapType(VARCHAR, createVarcharType(3)), ImmutableMap.of("19a", "abc", "1ad", "def", "1bx", "xyz")); assertFunction( "transform_keys(map(ARRAY [25, 26], ARRAY [ARRAY ['a'], ARRAY ['b']]), (k, v) -> ARRAY [CAST(k AS VARCHAR)] || v)", - new MapType(new ArrayType(VARCHAR), new ArrayType(createVarcharType(1))), + mapType(new ArrayType(VARCHAR), new ArrayType(createVarcharType(1))), ImmutableMap.of(ImmutableList.of("25", "a"), ImmutableList.of("a"), ImmutableList.of("26", "b"), ImmutableList.of("b"))); assertFunction( "transform_keys(map(ARRAY [25.5, 26.5, 27.5], ARRAY [25, 26, 27]), (k, v) -> CAST(k * 2 AS BIGINT) + v)", - new MapType(BIGINT, INTEGER), + mapType(BIGINT, INTEGER), ImmutableMap.of(76L, 25, 79L, 26, 82L, 27)); assertFunction( "transform_keys(map(ARRAY [25.5, 26.5, 27.5], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k + v)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(51.0, 25.5, 53.0, 26.5, 55.0, 27.5)); assertFunction( "transform_keys(map(ARRAY [25.2, 26.2], ARRAY [false, true]), (k, v) -> CAST(k AS BIGINT) % 2 = 0 OR v)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, false, true, true)); assertFunction( "transform_keys(map(ARRAY [25.5, 26.5, 27.5], ARRAY ['abc', 'def', 'xyz']), (k, v) -> CAST(k AS VARCHAR) || substr(v, 1, 1))", - new MapType(VARCHAR, createVarcharType(3)), + mapType(VARCHAR, createVarcharType(3)), ImmutableMap.of("25.5a", "abc", "26.5d", "def", "27.5x", "xyz")); assertFunction( "transform_keys(map(ARRAY [25.5, 26.5], ARRAY [ARRAY ['a'], ARRAY ['b']]), (k, v) -> ARRAY [CAST(k AS VARCHAR)] || v)", - new MapType(new ArrayType(VARCHAR), new ArrayType(createVarcharType(1))), + mapType(new ArrayType(VARCHAR), new ArrayType(createVarcharType(1))), ImmutableMap.of(ImmutableList.of("25.5", "a"), ImmutableList.of("a"), ImmutableList.of("26.5", "b"), ImmutableList.of("b"))); assertFunction( "transform_keys(map(ARRAY [true, false], ARRAY [25, 26]), (k, v) -> if(k, 2 * v, 3 * v))", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(50, 25, 78, 26)); assertFunction( "transform_keys(map(ARRAY [false, true], ARRAY [25.5, 26.5]), (k, v) -> if(k, 2 * v, 3 * v))", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(76.5, 25.5, 53.0, 26.5)); Map expectedBoolBoolMap = new HashMap<>(); expectedBoolBoolMap.put(false, true); expectedBoolBoolMap.put(true, null); assertFunction( "transform_keys(map(ARRAY [true, false], ARRAY [true, NULL]), (k, v) -> if(k, NOT v, v IS NULL))", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), expectedBoolBoolMap); assertFunction( "transform_keys(map(ARRAY [false, true], ARRAY ['abc', 'def']), (k, v) -> if(k, substr(v, 1, 2), substr(v, 1, 1)))", - new MapType(createVarcharType(3), createVarcharType(3)), + mapType(createVarcharType(3), createVarcharType(3)), ImmutableMap.of("a", "abc", "de", "def")); assertFunction( "transform_keys(map(ARRAY [true, false], ARRAY [ARRAY ['a', 'b'], ARRAY ['x', 'y']]), (k, v) -> if(k, reverse(v), v))", - new MapType(new ArrayType(createVarcharType(1)), new ArrayType(createVarcharType(1))), + mapType(new ArrayType(createVarcharType(1)), new ArrayType(createVarcharType(1))), ImmutableMap.of(ImmutableList.of("b", "a"), ImmutableList.of("a", "b"), ImmutableList.of("x", "y"), ImmutableList.of("x", "y"))); assertFunction( "transform_keys(map(ARRAY ['a', 'ab', 'abc'], ARRAY [25, 26, 27]), (k, v) -> length(k) + v)", - new MapType(BIGINT, INTEGER), + mapType(BIGINT, INTEGER), ImmutableMap.of(26L, 25, 28L, 26, 30L, 27)); assertFunction( "transform_keys(map(ARRAY ['a', 'ab', 'abc'], ARRAY [25.5, 26.5, 27.5]), (k, v) -> length(k) + v)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(26.5, 25.5, 28.5, 26.5, 30.5, 27.5)); assertFunction( "transform_keys(map(ARRAY ['a', 'b'], ARRAY [false, true]), (k, v) -> k = 'b' OR v)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, false, true, true)); assertFunction( "transform_keys(map(ARRAY ['a', 'x'], ARRAY ['bc', 'yz']), (k, v) -> k || v)", - new MapType(VARCHAR, createVarcharType(2)), + mapType(VARCHAR, createVarcharType(2)), ImmutableMap.of("abc", "bc", "xyz", "yz")); assertFunction( "transform_keys(map(ARRAY ['x', 'y'], ARRAY [ARRAY ['a'], ARRAY ['b']]), (k, v) -> k || v)", - new MapType(new ArrayType(createVarcharType(1)), new ArrayType(createVarcharType(1))), + mapType(new ArrayType(createVarcharType(1)), new ArrayType(createVarcharType(1))), ImmutableMap.of(ImmutableList.of("x", "a"), ImmutableList.of("a"), ImmutableList.of("y", "b"), ImmutableList.of("b"))); assertFunction( "transform_keys(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [25, 26]), (k, v) -> reduce(k, 0, (s, x) -> s + x, s -> s) + v)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(28, 25, 33, 26)); assertFunction( "transform_keys(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [25.5, 26.5]), (k, v) -> reduce(k, 0, (s, x) -> s + x, s -> s) + v)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(28.5, 25.5, 33.5, 26.5)); assertFunction( "transform_keys(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [false, true]), (k, v) -> contains(k, 3) AND v)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, false, true, true)); assertFunction( "transform_keys(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY ['abc', 'xyz']), (k, v) -> transform(k, x -> CAST(x AS VARCHAR)) || v)", - new MapType(new ArrayType(VARCHAR), createVarcharType(3)), + mapType(new ArrayType(VARCHAR), createVarcharType(3)), ImmutableMap.of(ImmutableList.of("1", "2", "abc"), "abc", ImmutableList.of("3", "4", "xyz"), "xyz")); assertFunction( "transform_keys(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [ARRAY ['a'], ARRAY ['a', 'b']]), (k, v) -> transform(k, x -> CAST(x AS VARCHAR)) || v)", - new MapType(new ArrayType(VARCHAR), new ArrayType(createVarcharType(1))), + mapType(new ArrayType(VARCHAR), new ArrayType(createVarcharType(1))), ImmutableMap.of(ImmutableList.of("1", "2", "a"), ImmutableList.of("a"), ImmutableList.of("3", "4", "a", "b"), ImmutableList.of("a", "b"))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java index a8524a7af49b0..78cfa2b173851 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMapTransformValueFunction.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.ArrayType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -29,6 +28,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestMapTransformValueFunction extends AbstractTestFunctions @@ -37,15 +37,15 @@ public class TestMapTransformValueFunction public void testEmpty() throws Exception { - assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> NULL)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> k)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> v)", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - - assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> 0)", new MapType(UNKNOWN, INTEGER), ImmutableMap.of()); - assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> true)", new MapType(UNKNOWN, BOOLEAN), ImmutableMap.of()); - assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> 'value')", new MapType(UNKNOWN, createVarcharType(5)), ImmutableMap.of()); - assertFunction("transform_values(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> k + CAST(v as BIGINT))", new MapType(BIGINT, BIGINT), ImmutableMap.of()); - assertFunction("transform_values(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> CAST(k AS VARCHAR) || v)", new MapType(BIGINT, VARCHAR), ImmutableMap.of()); + assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> NULL)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> k)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> v)", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + + assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> 0)", mapType(UNKNOWN, INTEGER), ImmutableMap.of()); + assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> true)", mapType(UNKNOWN, BOOLEAN), ImmutableMap.of()); + assertFunction("transform_values(map(ARRAY[], ARRAY[]), (k, v) -> 'value')", mapType(UNKNOWN, createVarcharType(5)), ImmutableMap.of()); + assertFunction("transform_values(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> k + CAST(v as BIGINT))", mapType(BIGINT, BIGINT), ImmutableMap.of()); + assertFunction("transform_values(CAST (map(ARRAY[], ARRAY[]) AS MAP(BIGINT,VARCHAR)), (k, v) -> CAST(k AS VARCHAR) || v)", mapType(BIGINT, VARCHAR), ImmutableMap.of()); } @Test @@ -56,23 +56,23 @@ public void testNullValue() sequenceToNullMap.put(1, null); sequenceToNullMap.put(2, null); sequenceToNullMap.put(3, null); - assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY ['a', 'b', 'c']), (k, v) -> NULL)", new MapType(INTEGER, UNKNOWN), sequenceToNullMap); + assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY ['a', 'b', 'c']), (k, v) -> NULL)", mapType(INTEGER, UNKNOWN), sequenceToNullMap); Map mapWithNullValue = new HashMap<>(); mapWithNullValue.put(1, "a"); mapWithNullValue.put(2, "b"); mapWithNullValue.put(3, null); - assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY ['a', 'b', NULL]), (k, v) -> v)", new MapType(INTEGER, createVarcharType(1)), mapWithNullValue); - assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY [10, 11, NULL]), (k, v) -> to_base(v, 16))", new MapType(INTEGER, createVarcharType(64)), mapWithNullValue); - assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY ['10', '11', 'Invalid']), (k, v) -> to_base(TRY_CAST(v as BIGINT), 16))", new MapType(INTEGER, createVarcharType(64)), mapWithNullValue); + assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY ['a', 'b', NULL]), (k, v) -> v)", mapType(INTEGER, createVarcharType(1)), mapWithNullValue); + assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY [10, 11, NULL]), (k, v) -> to_base(v, 16))", mapType(INTEGER, createVarcharType(64)), mapWithNullValue); + assertFunction("transform_values(map(ARRAY[1, 2, 3], ARRAY ['10', '11', 'Invalid']), (k, v) -> to_base(TRY_CAST(v as BIGINT), 16))", mapType(INTEGER, createVarcharType(64)), mapWithNullValue); assertFunction( "transform_values(map(ARRAY[1, 2, 3], ARRAY [0, 0, 0]), (k, v) -> element_at(map(ARRAY[1, 2], ARRAY['a', 'b']), k + v))", - new MapType(INTEGER, createVarcharType(1)), + mapType(INTEGER, createVarcharType(1)), mapWithNullValue); assertFunction( "transform_values(map(ARRAY[1, 2, 3], ARRAY ['a', 'b', NULL]), (k, v) -> IF(v IS NULL, k + 1.0, k + 0.5))", - new MapType(INTEGER, DOUBLE), + mapType(INTEGER, DOUBLE), ImmutableMap.of(1, 1.5, 2, 2.5, 3, 4.0)); } @@ -82,22 +82,22 @@ public void testBasic() { assertFunction( "transform_values(map(ARRAY [1, 2, 3, 4], ARRAY [10, 20, 30, 40]), (k, v) -> k + v)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(1, 11, 2, 22, 3, 33, 4, 44)); assertFunction( "transform_values(map(ARRAY ['a', 'b', 'c', 'd'], ARRAY [1, 2, 3, 4]), (k, v) -> v * v)", - new MapType(createVarcharType(1), INTEGER), + mapType(createVarcharType(1), INTEGER), ImmutableMap.of("a", 1, "b", 4, "c", 9, "d", 16)); assertFunction( "transform_values(map(ARRAY ['a', 'b', 'c', 'd'], ARRAY [1, 2, 3, 4]), (k, v) -> k || CAST(v as VARCHAR))", - new MapType(createVarcharType(1), VARCHAR), + mapType(createVarcharType(1), VARCHAR), ImmutableMap.of("a", "a1", "b", "b2", "c", "c3", "d", "d4")); assertFunction( "transform_values(map(ARRAY[1, 2, 3], ARRAY [1.0, 1.4, 1.7]), (k, v) -> map(ARRAY[1, 2, 3], ARRAY['one', 'two', 'three'])[k] || '_' || CAST(v AS VARCHAR))", - new MapType(INTEGER, VARCHAR), + mapType(INTEGER, VARCHAR), ImmutableMap.of(1, "one_1.0", 2, "two_1.4", 3, "three_1.7")); } @@ -107,107 +107,107 @@ public void testTypeCombinations() { assertFunction( "transform_values(map(ARRAY [25, 26, 27], ARRAY [25, 26, 27]), (k, v) -> k + v)", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(25, 50, 26, 52, 27, 54)); assertFunction( "transform_values(map(ARRAY [25, 26, 27], ARRAY [26.1, 31.2, 37.1]), (k, v) -> CAST(v - k AS BIGINT))", - new MapType(INTEGER, BIGINT), + mapType(INTEGER, BIGINT), ImmutableMap.of(25, 1L, 26, 5L, 27, 10L)); assertFunction( "transform_values(map(ARRAY [25, 27], ARRAY [false, true]), (k, v) -> if(v, k + 1, k + 2))", - new MapType(INTEGER, INTEGER), + mapType(INTEGER, INTEGER), ImmutableMap.of(25, 27, 27, 28)); assertFunction( "transform_values(map(ARRAY [25, 26, 27], ARRAY ['abc', 'd', 'xy']), (k, v) -> k + length(v))", - new MapType(INTEGER, BIGINT), + mapType(INTEGER, BIGINT), ImmutableMap.of(25, 28L, 26, 27L, 27, 29L)); assertFunction( "transform_values(map(ARRAY [25, 26, 27], ARRAY [ARRAY ['a'], ARRAY ['a', 'c'], ARRAY ['a', 'b', 'c']]), (k, v) -> k + cardinality(v))", - new MapType(INTEGER, BIGINT), + mapType(INTEGER, BIGINT), ImmutableMap.of(25, 26L, 26, 28L, 27, 30L)); assertFunction( "transform_values(map(ARRAY [25.5, 26.75, 27.875], ARRAY [25, 26, 27]), (k, v) -> k - v)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(25.5, 0.5, 26.75, 0.75, 27.875, 0.875)); assertFunction( "transform_values(map(ARRAY [25.5, 26.75, 27.875], ARRAY [25.0, 26.0, 27.0]), (k, v) -> k - v)", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(25.5, 0.5, 26.75, 0.75, 27.875, 0.875)); assertFunction( "transform_values(map(ARRAY [25.5, 27.5], ARRAY [false, true]), (k, v) -> if(v, k + 0.1, k + 0.2))", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(25.5, 25.7, 27.5, 27.6)); assertFunction( "transform_values(map(ARRAY [25.5, 26.5, 27.5], ARRAY ['a', 'def', 'xy']), (k, v) -> k + length(v))", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(25.5, 26.5, 26.5, 29.5, 27.5, 29.5)); assertFunction( "transform_values(map(ARRAY [25.5, 26.5, 27.5], ARRAY [ARRAY ['a'], ARRAY ['a', 'c'], ARRAY ['a', 'b', 'c']]), (k, v) -> k + cardinality(v))", - new MapType(DOUBLE, DOUBLE), + mapType(DOUBLE, DOUBLE), ImmutableMap.of(25.5, 26.5, 26.5, 28.5, 27.5, 30.5)); assertFunction( "transform_values(map(ARRAY [true, false], ARRAY [25, 26]), (k, v) -> k AND v = 25)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(true, true, false, false)); assertFunction( "transform_values(map(ARRAY [false, true], ARRAY [25.5, 26.5]), (k, v) -> k OR v > 100)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, false, true, true)); assertFunction( "transform_values(map(ARRAY [true, false], ARRAY [false, null]), (k, v) -> NOT k OR v)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, true, true, false)); assertFunction( "transform_values(map(ARRAY [false, true], ARRAY ['abc', 'def']), (k, v) -> NOT k AND v = 'abc')", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, true, true, false)); assertFunction( "transform_values(map(ARRAY [true, false], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'b', 'c']]), (k, v) -> k OR cardinality(v) = 3)", - new MapType(BOOLEAN, BOOLEAN), + mapType(BOOLEAN, BOOLEAN), ImmutableMap.of(false, true, true, true)); assertFunction( "transform_values(map(ARRAY ['s0', 's1', 's2'], ARRAY [25, 26, 27]), (k, v) -> k || ':' || CAST(v as VARCHAR))", - new MapType(createVarcharType(2), VARCHAR), + mapType(createVarcharType(2), VARCHAR), ImmutableMap.of("s0", "s0:25", "s1", "s1:26", "s2", "s2:27")); assertFunction( "transform_values(map(ARRAY ['s0', 's1', 's2'], ARRAY [25.5, 26.5, 27.5]), (k, v) -> k || ':' || CAST(v as VARCHAR))", - new MapType(createVarcharType(2), VARCHAR), + mapType(createVarcharType(2), VARCHAR), ImmutableMap.of("s0", "s0:25.5", "s1", "s1:26.5", "s2", "s2:27.5")); assertFunction( "transform_values(map(ARRAY ['s0', 's2'], ARRAY [false, true]), (k, v) -> if(v, k, CAST(v AS VARCHAR)))", - new MapType(createVarcharType(2), VARCHAR), + mapType(createVarcharType(2), VARCHAR), ImmutableMap.of("s0", "false", "s2", "s2")); assertFunction( "transform_values(map(ARRAY ['s0', 's1', 's2'], ARRAY ['abc', 'def', 'xyz']), (k, v) -> k || ':' || v)", - new MapType(createVarcharType(2), VARCHAR), + mapType(createVarcharType(2), VARCHAR), ImmutableMap.of("s0", "s0:abc", "s1", "s1:def", "s2", "s2:xyz")); assertFunction( "transform_values(map(ARRAY ['s0', 's1', 's2'], ARRAY [ARRAY ['a', 'b'], ARRAY ['a', 'c'], ARRAY ['a', 'b', 'c']]), (k, v) -> k || ':' || array_max(v))", - new MapType(createVarcharType(2), VARCHAR), + mapType(createVarcharType(2), VARCHAR), ImmutableMap.of("s0", "s0:b", "s1", "s1:c", "s2", "s2:c")); assertFunction( "transform_values(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [25, 26]), (k, v) -> if(v % 2 = 0, reverse(k), k))", - new MapType(new ArrayType(INTEGER), new ArrayType(INTEGER)), + mapType(new ArrayType(INTEGER), new ArrayType(INTEGER)), ImmutableMap.of(ImmutableList.of(1, 2), ImmutableList.of(1, 2), ImmutableList.of(3, 4), ImmutableList.of(4, 3))); assertFunction( "transform_values(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [25.5, 26.5]), (k, v) -> CAST(k AS ARRAY(DOUBLE)) || v)", - new MapType(new ArrayType(INTEGER), new ArrayType(DOUBLE)), + mapType(new ArrayType(INTEGER), new ArrayType(DOUBLE)), ImmutableMap.of(ImmutableList.of(1, 2), ImmutableList.of(1., 2., 25.5), ImmutableList.of(3, 4), ImmutableList.of(3., 4., 26.5))); assertFunction( "transform_values(map(ARRAY [ARRAY [1, 2], ARRAY [3, 4]], ARRAY [false, true]), (k, v) -> if(v, reverse(k), k))", - new MapType(new ArrayType(INTEGER), new ArrayType(INTEGER)), + mapType(new ArrayType(INTEGER), new ArrayType(INTEGER)), ImmutableMap.of(ImmutableList.of(1, 2), ImmutableList.of(1, 2), ImmutableList.of(3, 4), ImmutableList.of(4, 3))); assertFunction( "transform_values(map(ARRAY [ARRAY [1, 2], ARRAY []], ARRAY ['a', 'ff']), (k, v) -> k || from_base(v, 16))", - new MapType(new ArrayType(INTEGER), new ArrayType(BIGINT)), + mapType(new ArrayType(INTEGER), new ArrayType(BIGINT)), ImmutableMap.of(ImmutableList.of(1, 2), ImmutableList.of(1L, 2L, 10L), ImmutableList.of(), ImmutableList.of(255L))); assertFunction( "transform_values(map(ARRAY [ARRAY [3, 4], ARRAY []], ARRAY [ARRAY ['a', 'b', 'c'], ARRAY ['a', 'c']]), (k, v) -> transform(k, x -> CAST(x AS VARCHAR)) || v)", - new MapType(new ArrayType(INTEGER), new ArrayType(VARCHAR)), + mapType(new ArrayType(INTEGER), new ArrayType(VARCHAR)), ImmutableMap.of(ImmutableList.of(3, 4), ImmutableList.of("3", "4", "a", "b", "c"), ImmutableList.of(), ImmutableList.of("a", "c"))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java index 99ed596745a02..70f686aacf065 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestPageProcessorCompiler.java @@ -20,13 +20,13 @@ import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.facebook.presto.spi.block.SliceArrayBlock; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.DeterminismEvaluator; import com.facebook.presto.sql.relational.InputReferenceExpression; import com.facebook.presto.sql.relational.RowExpression; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; @@ -192,7 +192,7 @@ private static DictionaryBlock createDictionaryBlock(Slice[] expectedValues, int for (int i = 0; i < positionCount; i++) { ids[i] = i % dictionarySize; } - return new DictionaryBlock(positionCount, new SliceArrayBlock(dictionarySize, expectedValues), ids); + return new DictionaryBlock(new SliceArrayBlock(dictionarySize, expectedValues), ids); } private static Slice[] createExpectedValues(int positionCount) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java index aebd9697329c5..9d1b9bcafd42e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestRegexpFunctions.java @@ -15,9 +15,9 @@ import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java index d222d4e83a74f..0cf66eab5de6d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java @@ -19,7 +19,9 @@ import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; import org.testng.annotations.Test; import javax.annotation.Nullable; @@ -285,6 +287,98 @@ public static long bad(@SqlType(StandardTypes.BIGINT) long value, @IsNull @SqlNu } } + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter to only contain A-Z and 0-9 \\(starting with A-Z\\), but got bad on method .*") + public void testNonUpperCaseTypeParameters() + { + extractScalars(TypeParameterWithNonUpperCaseAnnotation.class); + } + + public static final class TypeParameterWithNonUpperCaseAnnotation + { + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + @TypeParameter("bad") + public static long bad(@TypeParameter("array(bad)") Type type, @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter to only contain A-Z and 0-9 \\(starting with A-Z\\), but got 1E on method .*") + public void testLeadingNumericTypeParameters() + { + extractScalars(TypeParameterWithLeadingNumbers.class); + } + + public static final class TypeParameterWithLeadingNumbers + { + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + @TypeParameter("1E") + public static long bad(@TypeParameter("array(1E)") Type type, @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter not to take parameters, but got E on method .*") + public void testNonPrimitiveTypeParameters() + { + extractScalars(TypeParameterWithNonPrimitiveAnnotation.class); + } + + public static final class TypeParameterWithNonPrimitiveAnnotation + { + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + @TypeParameter("E") + public static long bad(@TypeParameter("E(VARCHAR)") Type type, @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } + } + + @Test + public void testValidTypeParameters() + { + extractScalars(ValidTypeParameter.class); + } + + public static final class ValidTypeParameter + { + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + public static long good1( + @TypeParameter("ROW(ARRAY(BIGINT),MAP(INTEGER,DECIMAL),SMALLINT,CHAR,BOOLEAN,DATE,TIMESTAMP,VARCHAR)") Type type, + @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } + + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + @TypeParameter("E12") + @TypeParameter("F34") + public static long good2( + @TypeParameter("ROW(ARRAY(E12),JSON,TIME,VARBINARY,ROW(ROW(F34)))") Type type, + @SqlType(StandardTypes.BIGINT) long value) + { + return value; + } + } + + @Test + public void testValidTypeParametersForConstructors() + { + extractScalars(ConstructorWithValidTypeParameters.class); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected type parameter not to take parameters, but got K on method .*") + public void testInvalidTypeParametersForConstructors() + { + extractScalars(ConstructorWithInvalidTypeParameters.class); + } + private static void extractParametricScalar(Class clazz) { new FunctionListBuilder().scalar(clazz); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java index 1e155a77e5125..bb7c3773091bd 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java @@ -17,11 +17,11 @@ import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.SqlVarbinary; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.LiteralParameter; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; @@ -37,6 +37,7 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_NOT_FOUND; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Strings.repeat; import static java.lang.String.format; @@ -375,7 +376,7 @@ public void testSplit() @Test public void testSplitToMap() { - MapType expectedType = new MapType(VARCHAR, VARCHAR); + MapType expectedType = mapType(VARCHAR, VARCHAR); assertFunction("SPLIT_TO_MAP('', ',', '=')", expectedType, ImmutableMap.of()); assertFunction("SPLIT_TO_MAP('a=123,b=.4,c=,=d', ',', '=')", expectedType, ImmutableMap.of("a", "123", "b", ".4", "c", "", "", "d")); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java index ebb505de9aa18..ef6b4af9f0510 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java @@ -160,6 +160,38 @@ public void testFromBigEndian64() assertInvalidFunction("from_big_endian_64(from_hex('000000000000000011'))", INVALID_FUNCTION_ARGUMENT); } + @Test + public void testToIEEE754Binary32() + throws Exception + { + assertFunction("to_ieee754_32(CAST(0.0 AS REAL))", VARBINARY, sqlVarbinaryHex("00000000")); + assertFunction("to_ieee754_32(CAST(1.0 AS REAL))", VARBINARY, sqlVarbinaryHex("3F800000")); + assertFunction("to_ieee754_32(CAST(3.14 AS REAL))", VARBINARY, sqlVarbinaryHex("4048F5C3")); + assertFunction("to_ieee754_32(CAST(NAN() AS REAL))", VARBINARY, sqlVarbinaryHex("7FC00000")); + assertFunction("to_ieee754_32(CAST(INFINITY() AS REAL))", VARBINARY, sqlVarbinaryHex("7F800000")); + assertFunction("to_ieee754_32(CAST(-INFINITY() AS REAL))", VARBINARY, sqlVarbinaryHex("FF800000")); + assertFunction("to_ieee754_32(CAST(3.4028235E38 AS REAL))", VARBINARY, sqlVarbinaryHex("7F7FFFFF")); + assertFunction("to_ieee754_32(CAST(-3.4028235E38 AS REAL))", VARBINARY, sqlVarbinaryHex("FF7FFFFF")); + assertFunction("to_ieee754_32(CAST(1.4E-45 AS REAL))", VARBINARY, sqlVarbinaryHex("00000001")); + assertFunction("to_ieee754_32(CAST(-1.4E-45 AS REAL))", VARBINARY, sqlVarbinaryHex("80000001")); + } + + @Test + public void testToIEEE754Binary64() + throws Exception + { + assertFunction("to_ieee754_64(0.0)", VARBINARY, sqlVarbinaryHex("0000000000000000")); + assertFunction("to_ieee754_64(1.0)", VARBINARY, sqlVarbinaryHex("3FF0000000000000")); + assertFunction("to_ieee754_64(3.1415926)", VARBINARY, sqlVarbinaryHex("400921FB4D12D84A")); + assertFunction("to_ieee754_64(NAN())", VARBINARY, sqlVarbinaryHex("7FF8000000000000")); + assertFunction("to_ieee754_64(INFINITY())", VARBINARY, sqlVarbinaryHex("7FF0000000000000")); + assertFunction("to_ieee754_64(-INFINITY())", VARBINARY, sqlVarbinaryHex("FFF0000000000000")); + assertFunction("to_ieee754_64(1.7976931348623157E308)", VARBINARY, sqlVarbinaryHex("7FEFFFFFFFFFFFFF")); + assertFunction("to_ieee754_64(-1.7976931348623157E308)", VARBINARY, sqlVarbinaryHex("FFEFFFFFFFFFFFFF")); + assertFunction("to_ieee754_64(4.9E-324)", VARBINARY, sqlVarbinaryHex("0000000000000001")); + assertFunction("to_ieee754_64(-4.9E-324)", VARBINARY, sqlVarbinaryHex("8000000000000001")); + } + @Test public void testMd5() throws Exception @@ -214,6 +246,17 @@ public void testHashCode() assertEquals(VarbinaryOperators.hashCode(data), VARBINARY.hash(block, 0)); } + @Test + public void testCrc32() + throws Exception + { + assertFunction("crc32(to_utf8('CRC me!'))", BIGINT, 38028046L); + assertFunction("crc32(to_utf8('1234567890'))", BIGINT, 639479525L); + assertFunction("crc32(to_utf8(CAST(1234567890 AS VARCHAR)))", BIGINT, 639479525L); + assertFunction("crc32(to_utf8('ABCDEFGHIJK'))", BIGINT, 1129618807L); + assertFunction("crc32(to_utf8('ABCDEFGHIJKLM'))", BIGINT, 4223167559L); + } + private static String encodeBase64(byte[] value) { return Base64.getEncoder().encodeToString(value); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java index 687e15367909b..a61c042558fbe 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipFunction.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.RowType; import org.testng.annotations.Test; import java.util.List; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java index 590b1d5986635..dd9a177bf6c2f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestZipWithFunction.java @@ -13,9 +13,8 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -28,6 +27,7 @@ import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.facebook.presto.type.UnknownType.UNKNOWN; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.util.Arrays.asList; public class TestZipWithFunction @@ -66,7 +66,7 @@ public void testSameLength() ImmutableList.of("ac", "bd")); assertFunction("zip_with(ARRAY[MAP(ARRAY[CAST ('a' AS VARCHAR)], ARRAY[1]), MAP(ARRAY[CAST('b' AS VARCHAR)], ARRAY[2])], ARRAY[MAP(ARRAY['c'], ARRAY[3]), MAP()], (x, y) -> map_concat(x, y))", - new ArrayType(new MapType(VARCHAR, INTEGER)), + new ArrayType(mapType(VARCHAR, INTEGER)), ImmutableList.of(ImmutableMap.of("a", 1, "c", 3), ImmutableMap.of("b", 2))); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/window/TestMapAggFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/window/TestMapAggFunction.java index 28a203f86f284..8a763c83b0381 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/window/TestMapAggFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/window/TestMapAggFunction.java @@ -14,13 +14,13 @@ package com.facebook.presto.operator.window; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.MapType; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestMapAggFunction extends AbstractTestWindowFunction @@ -29,7 +29,7 @@ public class TestMapAggFunction public void testMapAgg() { assertWindowQuery("map_agg(orderkey, orderstatus) OVER(PARTITION BY orderdate)", - resultBuilder(TEST_SESSION, BIGINT, VarcharType.createVarcharType(1), new MapType(BIGINT, VarcharType.createVarcharType(1))) + resultBuilder(TEST_SESSION, BIGINT, VarcharType.createVarcharType(1), mapType(BIGINT, VarcharType.createVarcharType(1))) .row(1, "O", ImmutableMap.of(1, "O")) .row(2, "O", ImmutableMap.of(2, "O")) .row(3, "F", ImmutableMap.of(3, "F")) diff --git a/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java b/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java index d9c5b2fe853dc..877c6b678cb57 100644 --- a/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java +++ b/presto-main/src/test/java/com/facebook/presto/security/TestAccessControlManager.java @@ -263,6 +263,11 @@ public void checkCanSetUser(Principal principal, String userName) checkedUserName = userName; } + @Override + public void checkCanAccessCatalog(Identity identity, String catalogName) + { + } + @Override public void checkCanSetSystemSessionProperty(Identity identity, String propertyName) { @@ -276,6 +281,12 @@ public void checkCanSelectFromTable(Identity identity, CatalogSchemaTableName ta denySelectTable(table.toString()); } } + + @Override + public Set filterCatalogs(Identity identity, Set catalogs) + { + return catalogs; + } }; } } @@ -386,13 +397,13 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String property } @Override - public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { throw new UnsupportedOperationException(); } @Override - public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { throw new UnsupportedOperationException(); } diff --git a/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java b/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java new file mode 100644 index 0000000000000..7729bce92ce31 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/security/TestFileBasedSystemAccessControl.java @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.facebook.presto.metadata.QualifiedObjectName; +import com.facebook.presto.spi.CatalogSchemaName; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.security.AccessDeniedException; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.security.Privilege.SELECT; +import static com.facebook.presto.transaction.TransactionBuilder.transaction; +import static com.facebook.presto.transaction.TransactionManager.createTestTransactionManager; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; + +public class TestFileBasedSystemAccessControl +{ + private static final Identity alice = new Identity("alice", Optional.empty()); + private static final Identity bob = new Identity("bob", Optional.empty()); + private static final Identity admin = new Identity("admin", Optional.empty()); + private static final Identity nonAsciiUser = new Identity("\u0194\u0194\u0194", Optional.empty()); + private static final Set allCatalogs = ImmutableSet.of("secret", "open-to-all", "all-allowed", "alice-catalog", "allowed-absent", "\u0200\u0200\u0200"); + private static final QualifiedObjectName aliceTable = new QualifiedObjectName("alice-catalog", "schema", "table"); + private static final QualifiedObjectName aliceView = new QualifiedObjectName("alice-catalog", "schema", "view"); + private static final CatalogSchemaName aliceSchema = new CatalogSchemaName("alice-catalog", "schema"); + + @Test + public void testCatalogOperations() + { + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); + + transaction(transactionManager, accessControlManager) + .execute(transactionId -> { + assertEquals(accessControlManager.filterCatalogs(admin, allCatalogs), allCatalogs); + Set aliceCatalogs = ImmutableSet.of("open-to-all", "alice-catalog", "all-allowed"); + assertEquals(accessControlManager.filterCatalogs(alice, allCatalogs), aliceCatalogs); + Set bobCatalogs = ImmutableSet.of("open-to-all", "all-allowed"); + assertEquals(accessControlManager.filterCatalogs(bob, allCatalogs), bobCatalogs); + Set nonAsciiUserCatalogs = ImmutableSet.of("open-to-all", "all-allowed", "\u0200\u0200\u0200"); + assertEquals(accessControlManager.filterCatalogs(nonAsciiUser, allCatalogs), nonAsciiUserCatalogs); + }); + } + + @Test + public void testSchemaOperations() + { + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); + + transaction(transactionManager, accessControlManager) + .execute(transactionId -> { + Set aliceSchemas = ImmutableSet.of("schema"); + assertEquals(accessControlManager.filterSchemas(transactionId, alice, "alice-catalog", aliceSchemas), aliceSchemas); + assertEquals(accessControlManager.filterSchemas(transactionId, bob, "alice-catalog", aliceSchemas), ImmutableSet.of()); + + accessControlManager.checkCanCreateSchema(transactionId, alice, aliceSchema); + accessControlManager.checkCanDropSchema(transactionId, alice, aliceSchema); + accessControlManager.checkCanRenameSchema(transactionId, alice, aliceSchema, "new-schema"); + accessControlManager.checkCanShowSchemas(transactionId, alice, "alice-catalog"); + }); + assertThrows(AccessDeniedException.class, () -> transaction(transactionManager, accessControlManager).execute(transactionId -> { + accessControlManager.checkCanCreateSchema(transactionId, bob, aliceSchema); + })); + } + + @Test + public void testTableOperations() + { + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); + + transaction(transactionManager, accessControlManager) + .execute(transactionId -> { + Set aliceTables = ImmutableSet.of(new SchemaTableName("schema", "table")); + assertEquals(accessControlManager.filterTables(transactionId, alice, "alice-catalog", aliceTables), aliceTables); + assertEquals(accessControlManager.filterTables(transactionId, bob, "alice-catalog", aliceTables), ImmutableSet.of()); + + accessControlManager.checkCanCreateTable(transactionId, alice, aliceTable); + accessControlManager.checkCanDropTable(transactionId, alice, aliceTable); + accessControlManager.checkCanSelectFromTable(transactionId, alice, aliceTable); + accessControlManager.checkCanInsertIntoTable(transactionId, alice, aliceTable); + accessControlManager.checkCanDeleteFromTable(transactionId, alice, aliceTable); + accessControlManager.checkCanAddColumns(transactionId, alice, aliceTable); + accessControlManager.checkCanRenameColumn(transactionId, alice, aliceTable); + }); + assertThrows(AccessDeniedException.class, () -> transaction(transactionManager, accessControlManager).execute(transactionId -> { + accessControlManager.checkCanCreateTable(transactionId, bob, aliceTable); + })); + } + + @Test + public void testViewOperations() + throws Exception + { + TransactionManager transactionManager = createTestTransactionManager(); + AccessControlManager accessControlManager = newAccessControlManager(transactionManager); + + transaction(transactionManager, accessControlManager) + .execute(transactionId -> { + accessControlManager.checkCanCreateView(transactionId, alice, aliceView); + accessControlManager.checkCanDropView(transactionId, alice, aliceView); + accessControlManager.checkCanSelectFromView(transactionId, alice, aliceView); + accessControlManager.checkCanCreateViewWithSelectFromTable(transactionId, alice, aliceTable); + accessControlManager.checkCanCreateViewWithSelectFromView(transactionId, alice, aliceView); + accessControlManager.checkCanSetCatalogSessionProperty(transactionId, alice, "alice-catalog", "property"); + accessControlManager.checkCanGrantTablePrivilege(transactionId, alice, SELECT, aliceTable, "grantee", true); + accessControlManager.checkCanRevokeTablePrivilege(transactionId, alice, SELECT, aliceTable, "revokee", true); + }); + assertThrows(AccessDeniedException.class, () -> transaction(transactionManager, accessControlManager).execute(transactionId -> { + accessControlManager.checkCanCreateView(transactionId, bob, aliceView); + })); + } + + private AccessControlManager newAccessControlManager(TransactionManager transactionManager) + { + AccessControlManager accessControlManager = new AccessControlManager(transactionManager); + + String path = this.getClass().getClassLoader().getResource("catalog.json").getPath(); + accessControlManager.setSystemAccessControl(FileBasedSystemAccessControl.NAME, ImmutableMap.of("security.config-file", path)); + + return accessControlManager; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java b/presto-main/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java index 858a8c930ffe8..f46c36146780e 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java @@ -29,6 +29,7 @@ import java.net.URI; import java.util.Optional; +import java.util.OptionalDouble; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.execution.QueryState.RUNNING; @@ -66,6 +67,7 @@ public void testConstructor() 16, 17, 18, + 34, 19, 20.0, DataSize.valueOf("21GB"), @@ -123,6 +125,8 @@ public void testConstructor() assertEquals(basicInfo.getQueryStats().isFullyBlocked(), true); assertEquals(basicInfo.getQueryStats().getBlockedReasons(), ImmutableSet.of(BlockedReason.WAITING_FOR_MEMORY)); + assertEquals(basicInfo.getQueryStats().getProgressPercentage(), OptionalDouble.of(100)); + assertEquals(basicInfo.getErrorCode(), StandardErrorCode.ABANDONED_QUERY.toErrorCode()); assertEquals(basicInfo.getErrorType(), StandardErrorCode.ABANDONED_QUERY.toErrorCode().getType()); } diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java b/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java index 54348c242eca6..c85ea347547e9 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java @@ -41,6 +41,8 @@ import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_QUEUE; import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; import static io.airlift.units.DataSize.Unit.BYTE; +import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.HOURS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -59,7 +61,9 @@ public void testQueryStateInfo() groupRootAX, new DataSize(6000, BYTE), 1, + null, 10, + null, CAN_QUEUE, 0, new DataSize(4000, BYTE), @@ -71,7 +75,9 @@ public void testQueryStateInfo() groupRootAY, new DataSize(8000, BYTE), 1, + new Duration(10, HOURS), 10, + new Duration(1, DAYS), CAN_RUN, 0, new DataSize(0, BYTE), @@ -83,7 +89,9 @@ public void testQueryStateInfo() groupRootA, new DataSize(8000, BYTE), 1, + null, 10, + null, CAN_QUEUE, 1, new DataSize(4000, BYTE), @@ -95,7 +103,9 @@ public void testQueryStateInfo() groupRootB, new DataSize(8000, BYTE), 1, + new Duration(10, HOURS), 10, + new Duration(1, DAYS), CAN_QUEUE, 0, new DataSize(4000, BYTE), @@ -107,7 +117,9 @@ public void testQueryStateInfo() new ResourceGroupId("root"), new DataSize(10000, BYTE), 2, + null, 20, + null, CAN_QUEUE, 0, new DataSize(6000, BYTE), @@ -226,6 +238,7 @@ private QueryInfo createQueryInfo(String queryId, QueryState state, String query 100, 17, 18, + 34, 19, 20.0, DataSize.valueOf("21GB"), diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestResourceGroupStateInfo.java b/presto-main/src/test/java/com/facebook/presto/server/TestResourceGroupStateInfo.java new file mode 100644 index 0000000000000..dd4e62f819b4e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/TestResourceGroupStateInfo.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupId; +import com.google.common.collect.ImmutableList; +import io.airlift.json.JsonCodec; +import io.airlift.units.DataSize; +import io.airlift.units.Duration; +import org.joda.time.DateTime; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.OptionalDouble; + +import static com.facebook.presto.execution.QueryState.RUNNING; +import static com.facebook.presto.spi.resourceGroups.ResourceGroupState.CAN_RUN; +import static io.airlift.units.DataSize.Unit.BYTE; +import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static java.util.concurrent.TimeUnit.HOURS; +import static org.testng.Assert.assertEquals; + +public class TestResourceGroupStateInfo +{ + @Test + public void testJsonRoundTrip() + { + ResourceGroupId resourceGroupId = new ResourceGroupId(ImmutableList.of("test", "user")); + ResourceGroupStateInfo expected = new ResourceGroupStateInfo( + resourceGroupId, + CAN_RUN, + new DataSize(10, GIGABYTE), + new DataSize(100, BYTE), + 10, + 100, + new Duration(1, HOURS), + new Duration(10, HOURS), + ImmutableList.of(new QueryStateInfo( + new QueryId("test_query"), + RUNNING, + Optional.of(resourceGroupId), + "SELECT * FROM t", + DateTime.parse("2017-06-12T21:39:48.658Z"), + "test_user", + Optional.of("catalog"), + Optional.of("schema"), + Optional.empty(), + Optional.of(new QueryProgressStats( + DateTime.parse("2017-06-12T21:39:50.966Z"), + 150060, + 243, + 1541, + 566038, + 1680000, + 24, + 124539, + 8283750, + false, + OptionalDouble.empty())))), + 0); + JsonCodec codec = JsonCodec.jsonCodec(ResourceGroupStateInfo.class); + ResourceGroupStateInfo actual = codec.fromJson(codec.toJson(expected)); + + assertEquals(actual.getId(), resourceGroupId); + assertEquals(actual.getState(), CAN_RUN); + assertEquals(actual.getSoftMemoryLimit(), new DataSize(10, GIGABYTE)); + assertEquals(actual.getMemoryUsage(), new DataSize(100, BYTE)); + assertEquals(actual.getMaxRunningQueries(), 10); + assertEquals(actual.getRunningTimeLimit(), new Duration(1, HOURS)); + assertEquals(actual.getMaxQueuedQueries(), 100); + assertEquals(actual.getQueuedTimeLimit(), new Duration(10, HOURS)); + assertEquals(actual.getNumQueuedQueries(), 0); + assertEquals(actual.getRunningQueries().size(), 1); + QueryStateInfo queryStateInfo = actual.getRunningQueries().get(0); + assertEquals(queryStateInfo.getQueryId(), new QueryId("test_query")); + assertEquals(queryStateInfo.getQueryState(), RUNNING); + assertEquals(queryStateInfo.getResourceGroupId(), Optional.of(resourceGroupId)); + assertEquals(queryStateInfo.getQuery(), "SELECT * FROM t"); + assertEquals(queryStateInfo.getCreateTime(), DateTime.parse("2017-06-12T21:39:48.658Z")); + assertEquals(queryStateInfo.getUser(), "test_user"); + assertEquals(queryStateInfo.getCatalog(), Optional.of("catalog")); + assertEquals(queryStateInfo.getSchema(), Optional.of("schema")); + assertEquals(queryStateInfo.getResourceGroupChain(), Optional.empty()); + QueryProgressStats progressStats = queryStateInfo.getProgress().get(); + assertEquals(progressStats.getExecutionStartTime(), DateTime.parse("2017-06-12T21:39:50.966Z")); + assertEquals(progressStats.getElapsedTimeMillis(), 150060); + assertEquals(progressStats.getQueuedTimeMillis(), 243); + assertEquals(progressStats.getCpuTimeMillis(), 1541); + assertEquals(progressStats.getScheduledTimeMillis(), 566038); + assertEquals(progressStats.getBlockedTimeMillis(), 1680000); + assertEquals(progressStats.getPeakMemoryBytes(), 24); + assertEquals(progressStats.getInputRows(), 124539); + assertEquals(progressStats.getInputBytes(), 8283750); + assertEquals(progressStats.isBlocked(), false); + assertEquals(progressStats.getProgressPercentage(), OptionalDouble.empty()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java index 9b469284d850a..188ee2d1611a5 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java @@ -25,88 +25,129 @@ import com.facebook.presto.execution.TaskStatus; import com.facebook.presto.execution.TaskTestUtils; import com.facebook.presto.execution.TestSqlTaskManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.PrestoNode; import com.facebook.presto.server.HttpRemoteTaskFactory; import com.facebook.presto.server.TaskUpdateRequest; -import com.google.common.collect.ImmutableListMultimap; +import com.facebook.presto.spi.ErrorCode; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.testing.TestingHandleResolver; +import com.facebook.presto.type.TypeDeserializer; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableMultimap; -import io.airlift.http.client.HttpStatus; -import io.airlift.http.client.Request; -import io.airlift.http.client.Response; +import com.google.inject.Binder; +import com.google.inject.Injector; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Scopes; +import io.airlift.bootstrap.Bootstrap; import io.airlift.http.client.testing.TestingHttpClient; -import io.airlift.http.client.testing.TestingResponse; +import io.airlift.jaxrs.JsonMapper; +import io.airlift.jaxrs.testing.JaxrsTestingHttpProcessor; import io.airlift.json.JsonCodec; +import io.airlift.json.JsonModule; import io.airlift.units.Duration; import org.testng.annotations.Test; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.DefaultValue; +import javax.ws.rs.GET; +import javax.ws.rs.HeaderParam; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.UriInfo; + import java.net.URI; -import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiConsumer; import static com.facebook.presto.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID; +import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.json.JsonBinder.jsonBinder; +import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; -import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; -import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; public class TestHttpRemoteTask { - // This timeout should never be reached because a daemon thread in test should fail the test and do proper cleanup. + // This 30 sec per-test timeout should never be reached because the test should fail and do proper cleanup after 20 sec. + private static final Duration IDLE_TIMEOUT = new Duration(3, SECONDS); + private static final Duration FAIL_TIMEOUT = new Duration(20, SECONDS); + private static final TaskManagerConfig TASK_MANAGER_CONFIG = new TaskManagerConfig() + // Shorten status refresh wait and info update interval so that we can have a shorter test timeout + .setStatusRefreshMaxWait(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 100, MILLISECONDS)) + .setInfoUpdateInterval(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 10, MILLISECONDS)); + + private static final boolean TRACE_HTTP = false; + @Test(timeOut = 30000) public void testRemoteTaskMismatch() - throws InterruptedException, ExecutionException + throws Exception { - Duration idleTimeout = new Duration(3, SECONDS); - Duration failTimeout = new Duration(20, SECONDS); + runTest(TestCase.TASK_MISMATCH); + } - JsonCodec taskStatusCodec = JsonCodec.jsonCodec(TaskStatus.class); - JsonCodec taskInfoCodec = JsonCodec.jsonCodec(TaskInfo.class); - TaskManagerConfig taskManagerConfig = new TaskManagerConfig(); + @Test(timeOut = 30000) + public void testRejectedExecutionWhenVersionIsHigh() + throws Exception + { + runTest(TestCase.TASK_MISMATCH_WHEN_VERSION_IS_HIGH); + } - // Shorten status refresh wait and info update interval so that we can have a shorter test timeout - taskManagerConfig.setStatusRefreshMaxWait(new Duration(idleTimeout.roundTo(MILLISECONDS) / 100, MILLISECONDS)); - taskManagerConfig.setInfoUpdateInterval(new Duration(idleTimeout.roundTo(MILLISECONDS) / 10, MILLISECONDS)); + @Test(timeOut = 30000) + public void testRejectedExecution() + throws Exception + { + runTest(TestCase.REJECTED_EXECUTION); + } + private void runTest(TestCase testCase) + throws Exception + { AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); - HttpProcessor httpProcessor = new HttpProcessor(taskStatusCodec, taskInfoCodec, lastActivityNanos); - TestingHttpClient testingHttpClient = new TestingHttpClient(httpProcessor); - - HttpRemoteTaskFactory httpRemoteTaskFactory = new HttpRemoteTaskFactory( - new QueryManagerConfig(), - taskManagerConfig, - testingHttpClient, - new TestSqlTaskManager.MockLocationFactory(), - taskStatusCodec, - taskInfoCodec, - JsonCodec.jsonCodec(TaskUpdateRequest.class), - new RemoteTaskStats()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, testCase); + + HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource); + RemoteTask remoteTask = httpRemoteTaskFactory.createRemoteTask( TEST_SESSION, new TaskId("test", 1, 2), - new PrestoNode("node-id", URI.create("http://192.0.1.2"), new NodeVersion("version"), false), + new PrestoNode("node-id", URI.create("http://fake.invalid/"), new NodeVersion("version"), false), TaskTestUtils.PLAN_FRAGMENT, ImmutableMultimap.of(), createInitialEmptyOutputBuffers(OutputBuffers.BufferType.BROADCAST), new NodeTaskMap.PartitionedSplitCountTracker(i -> { }), true); - httpProcessor.setInitialTaskInfo(remoteTask.getTaskInfo()); + testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); remoteTask.start(); CompletableFuture testComplete = new CompletableFuture<>(); asyncRun( - idleTimeout.roundTo(MILLISECONDS), - failTimeout.roundTo(MILLISECONDS), + IDLE_TIMEOUT.roundTo(MILLISECONDS), + FAIL_TIMEOUT.roundTo(MILLISECONDS), lastActivityNanos, () -> testComplete.complete(null), (message, cause) -> testComplete.completeExceptionally(new AssertionError(message, cause))); @@ -114,8 +155,72 @@ public void testRemoteTaskMismatch() httpRemoteTaskFactory.stop(); assertTrue(remoteTask.getTaskStatus().getState().isDone(), format("TaskStatus is not in a done state: %s", remoteTask.getTaskStatus())); - assertEquals(getOnlyElement(remoteTask.getTaskStatus().getFailures()).getErrorCode(), REMOTE_TASK_MISMATCH.toErrorCode()); assertTrue(remoteTask.getTaskInfo().getTaskStatus().getState().isDone(), format("TaskInfo is not in a done state: %s", remoteTask.getTaskInfo())); + + ErrorCode actualErrorCode = getOnlyElement(remoteTask.getTaskStatus().getFailures()).getErrorCode(); + switch (testCase) { + case TASK_MISMATCH: + case TASK_MISMATCH_WHEN_VERSION_IS_HIGH: + assertEquals(actualErrorCode, REMOTE_TASK_MISMATCH.toErrorCode()); + break; + case REJECTED_EXECUTION: + assertEquals(actualErrorCode, REMOTE_TASK_ERROR.toErrorCode()); + break; + default: + throw new UnsupportedOperationException(); + } + } + + private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskResource testingTaskResource) + throws Exception + { + Bootstrap app = new Bootstrap( + new JsonModule(), + new HandleJsonModule(), + new Module() + { + @Override + public void configure(Binder binder) + { + binder.bind(JsonMapper.class); + configBinder(binder).bindConfig(FeaturesConfig.class); + binder.bind(TypeRegistry.class).in(Scopes.SINGLETON); + binder.bind(TypeManager.class).to(TypeRegistry.class).in(Scopes.SINGLETON); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); + jsonCodecBinder(binder).bindJsonCodec(TaskStatus.class); + jsonCodecBinder(binder).bindJsonCodec(TaskInfo.class); + jsonCodecBinder(binder).bindJsonCodec(TaskUpdateRequest.class); + } + + @Provides + private HttpRemoteTaskFactory createHttpRemoteTaskFactory( + JsonMapper jsonMapper, + JsonCodec taskStatusCodec, + JsonCodec taskInfoCodec, + JsonCodec taskUpdateRequestCodec) + { + JaxrsTestingHttpProcessor jaxrsTestingHttpProcessor = new JaxrsTestingHttpProcessor(URI.create("http://fake.invalid/"), testingTaskResource, jsonMapper); + TestingHttpClient testingHttpClient = new TestingHttpClient(jaxrsTestingHttpProcessor.setTrace(TRACE_HTTP)); + return new HttpRemoteTaskFactory( + new QueryManagerConfig(), + TASK_MANAGER_CONFIG, + testingHttpClient, + new TestSqlTaskManager.MockLocationFactory(), + taskStatusCodec, + taskInfoCodec, + taskUpdateRequestCodec, + new RemoteTaskStats()); + } + } + ); + Injector injector = app + .strictConfig() + .doNotInitializeLogging() + .initialize(); + HandleResolver handleResolver = injector.getInstance(HandleResolver.class); + handleResolver.addConnectorName("test", new TestingHandleResolver()); + return injector.getInstance(HttpRemoteTaskFactory.class); } private static void asyncRun(long idleTimeoutMillis, long failTimeoutMillis, AtomicLong lastActivityNanos, Runnable runAfterIdle, BiConsumer runAfterFail) @@ -146,52 +251,90 @@ private static void asyncRun(long idleTimeoutMillis, long failTimeoutMillis, Ato }).start(); } - private static class HttpProcessor implements TestingHttpClient.Processor + private enum TestCase + { + TASK_MISMATCH, + TASK_MISMATCH_WHEN_VERSION_IS_HIGH, + REJECTED_EXECUTION + } + + @Path("/task/{nodeId}") + public static class TestingTaskResource { private static final String INITIAL_TASK_INSTANCE_ID = "task-instance-id"; private static final String NEW_TASK_INSTANCE_ID = "task-instance-id-x"; - private final JsonCodec taskStatusCodec; - private final JsonCodec taskInfoCodec; + private final AtomicLong lastActivityNanos; + private final TestCase testCase; private TaskInfo initialTaskInfo; private TaskStatus initialTaskStatus; private long version; private TaskState taskState; + private String taskInstanceId = INITIAL_TASK_INSTANCE_ID; private long statusFetchCounter; - private String taskInstanceId = INITIAL_TASK_INSTANCE_ID; - public HttpProcessor(JsonCodec taskStatusCodec, JsonCodec taskInfoCodec, AtomicLong lastActivityNanos) + public TestingTaskResource(AtomicLong lastActivityNanos, TestCase testCase) { - this.taskStatusCodec = taskStatusCodec; - this.taskInfoCodec = taskInfoCodec; - this.lastActivityNanos = lastActivityNanos; + this.lastActivityNanos = requireNonNull(lastActivityNanos, "lastActivityNanos is null"); + this.testCase = requireNonNull(testCase, "testCase is null"); } - @Override - public synchronized Response handle(Request request) - throws Exception + @GET + @Path("{taskId}") + @Produces(MediaType.APPLICATION_JSON) + public synchronized TaskInfo getTaskInfo( + @PathParam("taskId") final TaskId taskId, + @HeaderParam(PRESTO_CURRENT_STATE) TaskState currentState, + @HeaderParam(PRESTO_MAX_WAIT) Duration maxWait, + @Context UriInfo uriInfo) { lastActivityNanos.set(System.nanoTime()); + return buildTaskInfo(); + } - ImmutableListMultimap.Builder headers = ImmutableListMultimap.builder(); - headers.put(PRESTO_TASK_INSTANCE_ID, taskInstanceId); - headers.put(CONTENT_TYPE, "application/json"); + @POST + @Path("{taskId}") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public synchronized TaskInfo createOrUpdateTask( + @PathParam("taskId") TaskId taskId, + TaskUpdateRequest taskUpdateRequest, + @Context UriInfo uriInfo) + { + lastActivityNanos.set(System.nanoTime()); + return buildTaskInfo(); + } - if (request.getUri().getPath().endsWith("/status")) { - statusFetchCounter++; - if (statusFetchCounter >= 10) { - // Change the task instance id after 10th fetch to simulate worker restart - taskInstanceId = NEW_TASK_INSTANCE_ID; - } - wait(Duration.valueOf(request.getHeader(PRESTO_MAX_WAIT)).roundTo(MILLISECONDS)); - return new TestingResponse(HttpStatus.OK, headers.build(), taskStatusCodec.toJson(buildTaskStatus()).getBytes(StandardCharsets.UTF_8)); - } - if ("DELETE".equals(request.getMethod())) { - taskState = TaskState.ABORTED; - } - return new TestingResponse(HttpStatus.OK, headers.build(), taskInfoCodec.toJson(buildTaskInfo()).getBytes(StandardCharsets.UTF_8)); + @GET + @Path("{taskId}/status") + @Produces(MediaType.APPLICATION_JSON) + public synchronized TaskStatus getTaskStatus( + @PathParam("taskId") TaskId taskId, + @HeaderParam(PRESTO_CURRENT_STATE) TaskState currentState, + @HeaderParam(PRESTO_MAX_WAIT) Duration maxWait, + @Context UriInfo uriInfo) + throws InterruptedException + { + lastActivityNanos.set(System.nanoTime()); + + wait(maxWait.roundTo(MILLISECONDS)); + return buildTaskStatus(); + } + + @DELETE + @Path("{taskId}") + @Produces(MediaType.APPLICATION_JSON) + public synchronized TaskInfo deleteTask( + @PathParam("taskId") TaskId taskId, + @QueryParam("abort") @DefaultValue("true") boolean abort, + @Context UriInfo uriInfo) + { + lastActivityNanos.set(System.nanoTime()); + + taskState = abort ? TaskState.ABORTED : TaskState.CANCELED; + return buildTaskInfo(); } public void setInitialTaskInfo(TaskInfo initialTaskInfo) @@ -200,6 +343,18 @@ public void setInitialTaskInfo(TaskInfo initialTaskInfo) this.initialTaskStatus = initialTaskInfo.getTaskStatus(); this.taskState = initialTaskStatus.getState(); this.version = initialTaskStatus.getVersion(); + switch (testCase) { + case TASK_MISMATCH_WHEN_VERSION_IS_HIGH: + // Make the initial version large enough. + // This way, the version number can't be reached if it is reset to 0. + version = 1_000_000; + break; + case TASK_MISMATCH: + case REJECTED_EXECUTION: + break; // do nothing + default: + throw new UnsupportedOperationException(); + } } private TaskInfo buildTaskInfo() @@ -216,6 +371,25 @@ private TaskInfo buildTaskInfo() private TaskStatus buildTaskStatus() { + statusFetchCounter++; + // Change the task instance id after 10th fetch to simulate worker restart + switch (testCase) { + case TASK_MISMATCH: + case TASK_MISMATCH_WHEN_VERSION_IS_HIGH: + if (statusFetchCounter == 10) { + taskInstanceId = NEW_TASK_INSTANCE_ID; + version = 0; + } + break; + case REJECTED_EXECUTION: + if (statusFetchCounter >= 10) { + throw new RejectedExecutionException(); + } + break; + default: + throw new UnsupportedOperationException(); + } + return new TaskStatus( initialTaskStatus.getTaskId(), taskInstanceId, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java b/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java index f8276e6a56ae4..2827964ecc4ac 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java @@ -28,9 +28,9 @@ import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LikePredicate; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.StringLiteral; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -1268,7 +1268,7 @@ private static Object optimize(@Language("SQL") String expression) Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - IdentityLinkedHashMap expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, parsedExpression, emptyList()); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, parsedExpression, emptyList()); ExpressionInterpreter interpreter = expressionOptimizer(parsedExpression, METADATA, TEST_SESSION, expressionTypes); return interpreter.optimize(symbol -> { switch (symbol.getName().toLowerCase(ENGLISH)) { @@ -1315,7 +1315,7 @@ private static void assertRoundTrip(String expression) private static Object evaluate(Expression expression) { - IdentityLinkedHashMap expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyList()); + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyList()); ExpressionInterpreter interpreter = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes); return interpreter.evaluate(null); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestSqlToRowExpressionTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/TestSqlToRowExpressionTranslator.java index 9c8e8be1afe16..ef3c7f5c704ff 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestSqlToRowExpressionTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/TestSqlToRowExpressionTranslator.java @@ -22,8 +22,9 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.type.TypeRegistry; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; @@ -39,12 +40,12 @@ public void testPossibleExponentialOptimizationTime() FunctionRegistry functionRegistry = new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); Expression expression = new LongLiteral("1"); - IdentityLinkedHashMap types = new IdentityLinkedHashMap<>(); - types.put(expression, BIGINT); + ImmutableMap.Builder, Type> types = ImmutableMap.builder(); + types.put(NodeRef.of(expression), BIGINT); for (int i = 0; i < 100; i++) { expression = new CoalesceExpression(expression); - types.put(expression, BIGINT); + types.put(NodeRef.of(expression), BIGINT); } - SqlToRowExpressionTranslator.translate(expression, SCALAR, types, functionRegistry, typeManager, TEST_SESSION, true); + SqlToRowExpressionTranslator.translate(expression, SCALAR, types.build(), functionRegistry, typeManager, TEST_SESSION, true); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index e5fd33cafe300..b73e056c9c5f4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -40,13 +40,13 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.testing.TestingMetadata; import com.facebook.presto.transaction.TransactionManager; -import com.facebook.presto.type.ArrayType; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import io.airlift.json.JsonCodec; @@ -65,7 +65,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.AMBIGUOUS_ATTRIBUTE; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.COLUMN_NAME_NOT_SPECIFIED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.COLUMN_TYPE_UNKNOWN; @@ -73,6 +73,8 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PROCEDURE_ARGUMENTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_SCHEMA_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_WINDOW_FRAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_COLUMN_ALIASES; @@ -92,6 +94,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_SELECT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SAMPLE_PERCENTAGE_OUT_OF_RANGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA; @@ -338,18 +341,58 @@ public void testNestedAggregation() public void testAggregationsNotAllowed() throws Exception { - assertFails(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, "SELECT * FROM t1 WHERE sum(a) > 1"); - assertFails(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, "SELECT * FROM t1 GROUP BY sum(a)"); - assertFails(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, "SELECT * FROM t1 JOIN t2 ON sum(t1.a) = t2.a"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT * FROM t1 WHERE sum(a) > 1"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT * FROM t1 GROUP BY sum(a)"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT * FROM t1 JOIN t2 ON sum(t1.a) = t2.a"); } @Test public void testWindowsNotAllowed() throws Exception { - assertFails(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, "SELECT * FROM t1 WHERE foo() over () > 1"); - assertFails(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, "SELECT * FROM t1 GROUP BY rank() over ()"); - assertFails(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, "SELECT * FROM t1 JOIN t2 ON sum(t1.a) over () = t2.a"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT * FROM t1 WHERE foo() over () > 1"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT * FROM t1 GROUP BY rank() over ()"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT * FROM t1 JOIN t2 ON sum(t1.a) over () = t2.a"); + } + + @Test + public void testGrouping() + throws Exception + { + analyze("SELECT a, b, sum(c), grouping(a, b) FROM t1 GROUP BY GROUPING SETS ((a), (a, b))"); + analyze("SELECT grouping(t1.a) FROM t1 GROUP BY a"); + analyze("SELECT grouping(b) FROM t1 GROUP BY t1.b"); + analyze("SELECT grouping(a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a) FROM t1 GROUP BY a"); + } + + @Test + public void testGroupingNotAllowed() + throws Exception + { + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT a, b, sum(c) FROM t1 WHERE grouping(a, b) GROUP BY GROUPING SETS ((a), (a, b))"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT a, b, sum(c) FROM t1 GROUP BY grouping(a, b)"); + assertFails(CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, "SELECT t1.a, t1.b FROM t1 JOIN t2 ON grouping(t1.a, t1.b) > t2.a"); + + assertFails(INVALID_PROCEDURE_ARGUMENTS, "SELECT grouping(a) FROM t1"); + assertFails(INVALID_PROCEDURE_ARGUMENTS, "SELECT * FROM t1 ORDER BY grouping(a)"); + assertFails(INVALID_PROCEDURE_ARGUMENTS, "SELECT grouping(a) FROM t1 GROUP BY b"); + assertFails(INVALID_PROCEDURE_ARGUMENTS, "SELECT grouping(a.field) FROM (VALUES ROW(CAST(ROW(1) AS ROW(field BIGINT)))) t(a) GROUP BY a.field"); + + assertFails(REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING, "SELECT a FROM t1 GROUP BY a ORDER BY grouping(a)"); + } + + @Test + public void testGroupingTooManyArguments() + throws Exception + { + String grouping = "GROUPING(a, a, a, a, a, a, a, a, a, a, a, a, a, a, a," + + "a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a," + + "a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a," + + "a, a)"; + assertFails(INVALID_PROCEDURE_ARGUMENTS, String.format("SELECT a, b, %s + 1 FROM t1 GROUP BY GROUPING SETS ((a), (a, b))", grouping)); + assertFails(INVALID_PROCEDURE_ARGUMENTS, String.format("SELECT a, b, %s as g FROM t1 GROUP BY a, b HAVING g > 0", grouping)); + assertFails(INVALID_PROCEDURE_ARGUMENTS, String.format("SELECT a, b, rank() OVER (PARTITION BY %s) FROM t1 GROUP BY GROUPING SETS ((a), (a, b))", grouping)); + assertFails(INVALID_PROCEDURE_ARGUMENTS, String.format("SELECT a, b, rank() OVER (PARTITION BY a ORDER BY %s) FROM t1 GROUP BY GROUPING SETS ((a), (a, b))", grouping)); } @Test @@ -959,6 +1002,14 @@ public void testGroupByCase() assertFails(MUST_BE_AGGREGATE_OR_GROUP_BY, "SELECT CASE WHEN true THEN 0 ELSE a END, count(*) FROM t1"); } + @Test + public void testGroupingWithWrongColumnsAndNoGroupBy() + throws Exception + { + assertFails(INVALID_PROCEDURE_ARGUMENTS, "SELECT a, SUM(b), GROUPING(a, b, c, d) FROM t1 GROUP BY GROUPING SETS ((a, b), (c))"); + assertFails(INVALID_PROCEDURE_ARGUMENTS, "SELECT a, SUM(b), GROUPING(a, b) FROM t1"); + } + @Test public void testMismatchedUnionQueries() throws Exception @@ -1175,6 +1226,11 @@ public void testLambdaInAggregationContext() MUST_BE_AGGREGATE_OR_GROUP_BY, ".* must be an aggregate expression or appear in GROUP BY clause", "SELECT apply(1, y -> x + y) FROM (VALUES (1,2)) t(x, y) GROUP BY x+y"); + assertFails( + MUST_BE_AGGREGATE_OR_GROUP_BY, + ".* must be an aggregate expression or appear in GROUP BY clause", + "SELECT apply(1, x -> y + transform(array[1], z -> x)[1]) FROM (VALUES (1, 2)) t(x,y) GROUP BY y + transform(array[1], z -> x)[1]" + ); } @Test @@ -1190,23 +1246,27 @@ public void testLambdaInSubqueryContext() } @Test - public void testLambdaWithAggregation() + public void testLambdaWithAggregationAndGrouping() throws Exception { assertFails( - CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, - ".* Lambda expression cannot contain aggregations or window functions: .*", + CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, + ".* Lambda expression cannot contain aggregations, window functions or grouping operations: .*", "SELECT transform(ARRAY[1], y -> max(x)) FROM (VALUES 10) t(x)"); // use of aggregation/window function on lambda variable assertFails( - CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, - ".* Lambda expression cannot contain aggregations or window functions: .*", + CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, + ".* Lambda expression cannot contain aggregations, window functions or grouping operations: .*", "SELECT apply(1, x -> max(x)) FROM (VALUES (1,2)) t(x,y) GROUP BY y"); assertFails( - CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, - ".* Lambda expression cannot contain aggregations or window functions: .*", + CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, + ".* Lambda expression cannot contain aggregations, window functions or grouping operations: .*", "SELECT apply(CAST(ROW(1) AS ROW(someField BIGINT)), x -> max(x.someField)) FROM (VALUES (1,2)) t(x,y) GROUP BY y"); + assertFails( + CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING, + ".* Lambda expression cannot contain aggregations, window functions or grouping operations: .*", + "SELECT apply(1, x -> grouping(x)) FROM (VALUES (1, 2)) t(x, y) GROUP BY y"); } @Test @@ -1256,6 +1316,36 @@ public void testLambdaWithSubqueryInOrderBy() "SELECT count(*) FROM t1 GROUP BY a ORDER BY (SELECT apply(0, x -> x + b))"); } + @Test + public void testLambdaWithInvalidParameterCount() + { + assertFails(INVALID_PARAMETER_USAGE, "line 1:17: Expected a lambda that takes 1 argument\\(s\\) but got 2", "SELECT apply(5, (x, y) -> 6)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:17: Expected a lambda that takes 1 argument\\(s\\) but got 3", "SELECT apply(5, (x, y, z) -> 6)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:21: Expected a lambda that takes 1 argument\\(s\\) but got 2", "SELECT TRY(apply(5, (x, y) -> x + 1) / 0)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:21: Expected a lambda that takes 1 argument\\(s\\) but got 3", "SELECT TRY(apply(5, (x, y, z) -> x + 1) / 0)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:29: Expected a lambda that takes 1 argument\\(s\\) but got 2", "SELECT filter(ARRAY [5, 6], (x, y) -> x = 5)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:29: Expected a lambda that takes 1 argument\\(s\\) but got 3", "SELECT filter(ARRAY [5, 6], (x, y, z) -> x = 5)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:52: Expected a lambda that takes 2 argument\\(s\\) but got 1", "SELECT map_filter(map(ARRAY [5, 6], ARRAY [5, 6]), (x) -> x = 1)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:52: Expected a lambda that takes 2 argument\\(s\\) but got 3", "SELECT map_filter(map(ARRAY [5, 6], ARRAY [5, 6]), (x, y, z) -> x = y + z)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:33: Expected a lambda that takes 2 argument\\(s\\) but got 1", "SELECT reduce(ARRAY [5, 20], 0, (s) -> s, s -> s)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:33: Expected a lambda that takes 2 argument\\(s\\) but got 3", "SELECT reduce(ARRAY [5, 20], 0, (s, x, z) -> s + x, s -> s + z)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:32: Expected a lambda that takes 1 argument\\(s\\) but got 2", "SELECT transform(ARRAY [5, 6], (x, y) -> x + y)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:32: Expected a lambda that takes 1 argument\\(s\\) but got 3", "SELECT transform(ARRAY [5, 6], (x, y, z) -> x + y + z)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:49: Expected a lambda that takes 2 argument\\(s\\) but got 1", "SELECT transform_keys(map(ARRAY[1], ARRAY [2]), k -> k)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:52: Expected a lambda that takes 2 argument\\(s\\) but got 3", "SELECT transform_keys(MAP(ARRAY['a'], ARRAY['b']), (k, v, x) -> k + 1)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:51: Expected a lambda that takes 2 argument\\(s\\) but got 1", "SELECT transform_values(map(ARRAY[1], ARRAY [2]), k -> k)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:51: Expected a lambda that takes 2 argument\\(s\\) but got 3", "SELECT transform_values(map(ARRAY[1], ARRAY [2]), (k, v, x) -> k + 1)"); + + assertFails(INVALID_PARAMETER_USAGE, "line 1:39: Expected a lambda that takes 2 argument\\(s\\) but got 1", "SELECT zip_with(ARRAY[1], ARRAY['a'], x -> x)"); + assertFails(INVALID_PARAMETER_USAGE, "line 1:39: Expected a lambda that takes 2 argument\\(s\\) but got 3", "SELECT zip_with(ARRAY[1], ARRAY['a'], (x, y, z) -> (x, y, z))"); + } + @Test public void testInvalidDelete() throws Exception @@ -1341,6 +1431,26 @@ public void testQuantifiedComparisonExpression() assertFails(TYPE_MISMATCH, "SELECT cast(NULL AS HyperLogLog) = ANY (VALUES cast(NULL AS HyperLogLog))"); } + @Test + public void testJoinUnnest() + throws Exception + { + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) CROSS JOIN UNNEST(x)"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN UNNEST(x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN UNNEST(x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN UNNEST(x) ON true"); + } + + @Test + public void testJoinLateral() + throws Exception + { + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) CROSS JOIN LATERAL(VALUES x)"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN LATERAL(VALUES x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN LATERAL(VALUES x) ON true"); + analyze("SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN LATERAL(VALUES x) ON true"); + } + @BeforeMethod(alwaysRun = true) public void setup() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 4742d33eb4903..4e66a6ed77a53 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -40,7 +40,7 @@ public void testDefaults() .setDistributedJoinsEnabled(true) .setFastInequalityJoins(true) .setColocatedJoinsEnabled(false) - .setJoinReorderingEnabled(false) + .setJoinReorderingEnabled(true) .setRedistributeWrites(true) .setOptimizeMetadataQueries(false) .setOptimizeHashGeneration(true) @@ -49,6 +49,7 @@ public void testDefaults() .setDictionaryAggregation(false) .setLegacyArrayAgg(false) .setLegacyMapSubscript(false) + .setNewMapBlock(true) .setRegexLibrary(JONI) .setRe2JDfaStatesLimit(Integer.MAX_VALUE) .setRe2JDfaRetries(5) @@ -62,7 +63,8 @@ public void testDefaults() .setIterativeOptimizerEnabled(true) .setIterativeOptimizerTimeout(new Duration(3, MINUTES)) .setExchangeCompressionEnabled(false) - .setEnableIntermediateAggregations(false)); + .setEnableIntermediateAggregations(false) + .setPushAggregationThroughJoin(true)); } @Test @@ -75,11 +77,12 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-array-agg", "true") .put("deprecated.legacy-order-by", "true") .put("deprecated.legacy-map-subscript", "true") + .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") .put("distributed-joins-enabled", "false") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") - .put("reorder-joins", "true") + .put("reorder-joins", "false") .put("redistribute-writes", "false") .put("optimizer.optimize-metadata-queries", "true") .put("optimizer.optimize-hash-generation", "false") @@ -87,6 +90,7 @@ public void testExplicitPropertyMappings() .put("optimizer.optimize-mixed-distinct-aggregations", "true") .put("optimizer.push-table-write-through-union", "false") .put("optimizer.dictionary-aggregation", "true") + .put("optimizer.push-aggregation-through-join", "false") .put("regex-library", "RE2J") .put("re2j.dfa-states-limit", "42") .put("re2j.dfa-retries", "42") @@ -105,11 +109,12 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-array-agg", "true") .put("deprecated.legacy-order-by", "true") .put("deprecated.legacy-map-subscript", "true") + .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") .put("distributed-joins-enabled", "false") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") - .put("reorder-joins", "true") + .put("reorder-joins", "false") .put("redistribute-writes", "false") .put("optimizer.optimize-metadata-queries", "true") .put("optimizer.optimize-hash-generation", "false") @@ -117,6 +122,7 @@ public void testExplicitPropertyMappings() .put("optimizer.optimize-mixed-distinct-aggregations", "true") .put("optimizer.push-table-write-through-union", "false") .put("optimizer.dictionary-aggregation", "true") + .put("optimizer.push-aggregation-through-join", "false") .put("regex-library", "RE2J") .put("re2j.dfa-states-limit", "42") .put("re2j.dfa-retries", "42") @@ -137,7 +143,7 @@ public void testExplicitPropertyMappings() .setDistributedJoinsEnabled(false) .setFastInequalityJoins(false) .setColocatedJoinsEnabled(true) - .setJoinReorderingEnabled(true) + .setJoinReorderingEnabled(false) .setRedistributeWrites(false) .setOptimizeMetadataQueries(true) .setOptimizeHashGeneration(false) @@ -145,8 +151,10 @@ public void testExplicitPropertyMappings() .setOptimizeMixedDistinctAggregations(true) .setPushTableWriteThroughUnion(false) .setDictionaryAggregation(true) + .setPushAggregationThroughJoin(false) .setLegacyArrayAgg(true) .setLegacyMapSubscript(true) + .setNewMapBlock(false) .setRegexLibrary(RE2J) .setRe2JDfaStatesLimit(42) .setRe2JDfaRetries(42) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java index 43bc4ab8f8c6c..61d72b87572bb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/PageProcessorBenchmark.java @@ -29,8 +29,8 @@ import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.testing.TestingSession; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.openjdk.jmh.annotations.Benchmark; @@ -174,7 +174,7 @@ private RowExpression rowExpression(String expression, Type type) } Map types = builder.build(); - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList()); + Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, METADATA, SQL_PARSER, types, inputReferenceExpression, emptyList()); return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, METADATA.getFunctionRegistry(), METADATA.getTypeManager(), TEST_SESSION, true); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java index 1e51c4f9fd1d6..cc640a9cc049c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinCompiler.java @@ -24,6 +24,7 @@ import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; +import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -193,7 +194,8 @@ public void testMultiChannel(boolean hashEnabled) // verify channel count assertEquals(hashStrategy.getChannelCount(), outputChannels.size()); // verify size - long sizeInBytes = channels.stream() + int instanceSize = ClassLayout.parseClass(hashStrategy.getClass()).instanceSize(); + long sizeInBytes = instanceSize + channels.stream() .flatMap(List::stream) .mapToLong(Block::getRetainedSizeInBytes) .sum(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestPageFunctionCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestPageFunctionCompiler.java new file mode 100644 index 0000000000000..bf4874aee6156 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestPageFunctionCompiler.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.gen; + +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.operator.project.PageProjection; +import com.facebook.presto.operator.project.SelectedPositions; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.sql.relational.RowExpression; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.function.Supplier; + +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; +import static com.facebook.presto.spi.function.OperatorType.ADD; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.relational.Expressions.field; +import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +public class TestPageFunctionCompiler +{ + @Test + public void testFailureDoesNotCorruptFutureResults() + throws Exception + { + PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager()); + + RowExpression add10 = call( + Signature.internalOperator(ADD, BIGINT.getTypeSignature(), ImmutableList.of(BIGINT.getTypeSignature(), BIGINT.getTypeSignature())), + BIGINT, + field(0, BIGINT), + constant(10L, BIGINT)); + + Supplier projectionSupplier = functionCompiler.compileProjection(add10); + PageProjection projection = projectionSupplier.get(); + + // process good page and verify we got the expected number of result rows + Page goodPage = createLongBlockPage(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + Block goodResult = projection.project(SESSION, goodPage, SelectedPositions.positionsRange(0, goodPage.getPositionCount())); + assertEquals(goodPage.getPositionCount(), goodResult.getPositionCount()); + + // addition will throw due to integer overflow + Page badPage = createLongBlockPage(0, 1, 2, 3, 4, Long.MAX_VALUE); + try { + projection.project(SESSION, badPage, SelectedPositions.positionsRange(0, 100)); + fail("expected exception"); + } + catch (PrestoException e) { + assertEquals(e.getErrorCode(), NUMERIC_VALUE_OUT_OF_RANGE.toErrorCode()); + } + + // running the good page should still work + // if block builder in generated code was not reset properly, we could get junk results after the failure + goodResult = projection.project(SESSION, goodPage, SelectedPositions.positionsRange(0, goodPage.getPositionCount())); + assertEquals(goodPage.getPositionCount(), goodResult.getPositionCount()); + } + + private static Page createLongBlockPage(long... values) + { + BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(values.length); + for (long value : values) { + BIGINT.writeLong(builder, value); + } + return new Page(builder.build()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java index c1853d2869a52..9714a705e6d6e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestVarArgsToArrayAdapterGenerator.java @@ -116,6 +116,7 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in false, nCopies(arity, false), nCopies(arity, false), + nCopies(arity, Optional.empty()), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()), isDeterministic()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java index 2130d0c6a7dc0..df90f78b2853b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java @@ -1336,23 +1336,25 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericValue public void testVarcharComparedToCharExpression() throws Exception { + String maxCodePoint = new String(Character.toChars(Character.MAX_CODE_POINT)); + // greater than or equal - testSimpleComparison(greaterThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), Slices.utf8Slice("12345678"))); + testSimpleComparison(greaterThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), utf8Slice("123456788" + maxCodePoint))); testSimpleComparison(greaterThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("1234567890", VARCHAR)), C_CHAR, Range.greaterThanOrEqual(createCharType(10), Slices.utf8Slice("1234567890"))); testSimpleComparison(greaterThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("12345678901", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), Slices.utf8Slice("1234567890"))); // greater than - testSimpleComparison(greaterThan(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), Slices.utf8Slice("12345678"))); + testSimpleComparison(greaterThan(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), utf8Slice("123456788" + maxCodePoint))); testSimpleComparison(greaterThan(cast(C_CHAR, VARCHAR), stringLiteral("1234567890", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), Slices.utf8Slice("1234567890"))); testSimpleComparison(greaterThan(cast(C_CHAR, VARCHAR), stringLiteral("12345678901", VARCHAR)), C_CHAR, Range.greaterThan(createCharType(10), Slices.utf8Slice("1234567890"))); // less than or equal - testSimpleComparison(lessThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), Slices.utf8Slice("12345678"))); + testSimpleComparison(lessThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), utf8Slice("123456788" + maxCodePoint))); testSimpleComparison(lessThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("1234567890", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), Slices.utf8Slice("1234567890"))); testSimpleComparison(lessThanOrEqual(cast(C_CHAR, VARCHAR), stringLiteral("12345678901", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), Slices.utf8Slice("1234567890"))); // less than - testSimpleComparison(lessThan(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), Slices.utf8Slice("12345678"))); + testSimpleComparison(lessThan(cast(C_CHAR, VARCHAR), stringLiteral("123456789", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), utf8Slice("123456788" + maxCodePoint))); testSimpleComparison(lessThan(cast(C_CHAR, VARCHAR), stringLiteral("1234567890", VARCHAR)), C_CHAR, Range.lessThan(createCharType(10), Slices.utf8Slice("1234567890"))); testSimpleComparison(lessThan(cast(C_CHAR, VARCHAR), stringLiteral("12345678901", VARCHAR)), C_CHAR, Range.lessThanOrEqual(createCharType(10), Slices.utf8Slice("1234567890"))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index e019e21c6d5ef..fd0aa2bec8bd6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -47,6 +48,8 @@ import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.facebook.presto.type.UnknownType; import com.google.common.base.Preconditions; import com.google.common.base.Predicates; @@ -152,9 +155,8 @@ public void testAggregation() lessThan(CE, DE), greaterThan(AE, bigintLiteral(2)), equals(EE, FE))), - ImmutableMap.of(C, fakeFunction("test"), D, fakeFunction("test")), - ImmutableMap.of(C, fakeFunctionHandle("test", AGGREGATE), D, fakeFunctionHandle("test", AGGREGATE)), - ImmutableMap.of(), + ImmutableMap.of(C, new Aggregation(fakeFunction("test"), fakeFunctionHandle("test", AGGREGATE), Optional.empty()), + D, new Aggregation(fakeFunction("test"), fakeFunctionHandle("test", AGGREGATE), Optional.empty())), ImmutableList.of(ImmutableList.of(A, B, C)), AggregationNode.Step.FINAL, Optional.empty(), @@ -179,8 +181,6 @@ public void testGroupByEmpty() newId(), filter(baseTableScan, FALSE_LITERAL), ImmutableMap.of(), - ImmutableMap.of(), - ImmutableMap.of(), ImmutableList.of(ImmutableList.of()), AggregationNode.Step.FINAL, Optional.empty(), @@ -238,7 +238,7 @@ public void testTopN() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - 1, ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST), true); + 1, ImmutableList.of(A), ImmutableMap.of(A, SortOrder.ASC_NULLS_LAST), TopNNode.Step.PARTIAL); Expression effectivePredicate = EffectivePredicateExtractor.extract(node, TYPES); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java index 01db56e686c5a..d33a34d7a256d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEqualityInference.java @@ -365,13 +365,13 @@ public void testExpressionsThatMayReturnNullOnNonNullInput() private static Predicate matchesSymbolScope(final Predicate symbolScope) { - return expression -> Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope); + return expression -> Iterables.all(SymbolsExtractor.extractUnique(expression), symbolScope); } private static Predicate matchesStraddlingScope(final Predicate symbolScope) { return expression -> { - Set symbols = DependencyExtractor.extractUnique(expression); + Set symbols = SymbolsExtractor.extractUnique(expression); return Iterables.any(symbols, symbolScope) && Iterables.any(symbols, not(symbolScope)); }; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 94e7fae7e0c7b..218e7ffb8014c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -15,12 +15,14 @@ import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.optimizations.AddLocalExchanges; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.ValuesNode; @@ -46,15 +48,17 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.lateral; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; -import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.tests.QueryTemplate.queryTemplate; +import static com.facebook.presto.util.MorePredicates.isInstanceOfAny; import static io.airlift.slice.Slices.utf8Slice; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -217,7 +221,7 @@ private static int countOfMatchingNodes(Plan plan, Predicate predicate @Test public void testRemoveUnreferencedScalarInputApplyNodes() { - assertPlanContainsNoApplyOrJoin("SELECT (SELECT 1)"); + assertPlanContainsNoApplyOrAnyJoin("SELECT (SELECT 1)"); } @Test @@ -230,10 +234,10 @@ public void testSubqueryPruning() queryTemplate("SELECT COUNT(*) FROM (SELECT %subquery% FROM orders)") .replaceAll(subqueries) - .forEach(this::assertPlanContainsNoApplyOrJoin); + .forEach(this::assertPlanContainsNoApplyOrAnyJoin); // TODO enable when pruning apply nodes works for this kind of query - // assertPlanContainsNoApplyOrJoin("SELECT * FROM orders WHERE true OR " + subquery); + // assertPlanContainsNoApplyOrAnyJoin("SELECT * FROM orders WHERE true OR " + subquery); } @Test @@ -247,16 +251,16 @@ public void testJoinOutputPruning() anyTree( tableScan("region", ImmutableMap.of("REGIONKEY_RIGHT", "regionkey")))) ) - .withNumberOfOutputColumns(1) - .withOutputs(ImmutableList.of("NATIONKEY")) + .withNumberOfOutputColumns(1) + .withOutputs(ImmutableList.of("NATIONKEY")) ); } - private void assertPlanContainsNoApplyOrJoin(String sql) + private void assertPlanContainsNoApplyOrAnyJoin(String sql) { assertFalse( searchFrom(plan(sql, LogicalPlanner.Stage.OPTIMIZED).getRoot()) - .where(isInstanceOfAny(ApplyNode.class, JoinNode.class, IndexJoinNode.class, SemiJoinNode.class)) + .where(isInstanceOfAny(ApplyNode.class, JoinNode.class, IndexJoinNode.class, SemiJoinNode.class, LateralJoinNode.class)) .matches(), "Unexpected node for query: " + sql); } @@ -264,24 +268,68 @@ private void assertPlanContainsNoApplyOrJoin(String sql) @Test public void testCorrelatedSubqueries() { - assertPlan( + assertPlanWithOptimizerFiltering( "SELECT orderkey FROM orders WHERE 3 = (SELECT orderkey)", LogicalPlanner.Stage.OPTIMIZED, anyTree( filter("BIGINT '3' = X", - apply(ImmutableList.of("X"), - ImmutableMap.of(), + lateral( + ImmutableList.of("X"), tableScan("orders", ImmutableMap.of("X", "orderkey")), node(EnforceSingleRowNode.class, project( node(ValuesNode.class) - )))))); + ))))), + planOptimizer -> !(planOptimizer instanceof AddLocalExchanges)); } + /** + * Handling of correlated IN pulls up everything possible to the generated outer join condition. + * This test ensures uncorrelated conditions are pushed back down. + */ @Test - public void testDoubleNestedCorrelatedSubqueries() + public void testCorrelatedInUncorrelatedFiltersPushDown() { assertPlan( + "SELECT orderkey, comment IN (SELECT clerk FROM orders s WHERE s.orderkey = o.orderkey AND s.orderkey < 7) FROM lineitem o", + anyTree( + node(JoinNode.class, + anyTree(tableScan("lineitem")), + anyTree( + filter("orderkey < BIGINT '7'", // pushed down + tableScan("orders", ImmutableMap.of("orderkey", "orderkey")) + ) + ) + ) + ) + ); + } + + /** + * Handling of correlated in predicate involves group by over all symbols from source. Once aggregation is added to the plan, + * it prevents pruning of the unreferenced symbols. However, the aggregation's result doesn't actually depended on those symbols + * and this test makes sure the symbols are pruned first. + */ + @Test + public void testSymbolsPrunedInCorrelatedInPredicateSource() + { + assertPlan( + "SELECT orderkey, comment IN (SELECT clerk FROM orders s WHERE s.orderkey = o.orderkey AND s.orderkey < 7) FROM lineitem o", + anyTree( + node(JoinNode.class, + anyTree(strictTableScan("lineitem", ImmutableMap.of( + "orderkey", "orderkey", + "comment", "comment"))), + anyTree(tableScan("orders")) + ) + ) + ); + } + + @Test + public void testDoubleNestedCorrelatedSubqueries() + { + assertPlanWithOptimizerFiltering( "SELECT orderkey FROM orders o " + "WHERE 3 IN (SELECT o.custkey FROM lineitem l WHERE (SELECT l.orderkey = o.orderkey))", LogicalPlanner.Stage.OPTIMIZED, @@ -294,20 +342,21 @@ public void testDoubleNestedCorrelatedSubqueries() "O", "orderkey", "C", "custkey"))), anyTree( - apply(ImmutableList.of("L"), - ImmutableMap.of(), + lateral( + ImmutableList.of("L"), tableScan("lineitem", ImmutableMap.of("L", "orderkey")), node(EnforceSingleRowNode.class, project( node(ValuesNode.class) - )))))))); + ))))))), + planOptimizer -> !(planOptimizer instanceof AddLocalExchanges)); } @Test public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() { assertPlan( - "SELECT orderkey FROM orders WHERE EXISTS(SELECT 1 WHERE orderkey = 3)", // EXISTS maps to count(*) = 1 + "SELECT orderkey FROM orders WHERE EXISTS(SELECT 1 WHERE orderkey = 3)", // EXISTS maps to count(*) > 0 anyTree( filter("FINAL_COUNT > BIGINT '0'", any( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPlanMatchingFramework.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPlanMatchingFramework.java index a9ee01661326b..ab2221e8f0edb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPlanMatchingFramework.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPlanMatchingFramework.java @@ -250,14 +250,6 @@ public void testDuplicateAliases() tableScan("lineitem").withAlias("ORDERS_OK", columnReference("lineitem", "orderkey")))))); } - @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = ".*already bound in.*") - public void testBindMultipleAliasesSameExpression() - { - assertMinimallyOptimizedPlan("SELECT orderkey FROM lineitem", - output(ImmutableList.of("ORDERKEY", "TWO"), - tableScan("lineitem", ImmutableMap.of("FIRST", "orderkey", "SECOND", "orderkey")))); - } - @Test(expectedExceptions = {IllegalStateException.class}, expectedExceptionsMessageRegExp = "missing expression for alias .*") public void testProjectLimitsScope() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java index 518c425df57d3..1625e645a70d9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java @@ -23,6 +23,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; @@ -43,4 +44,26 @@ public void testNonStraddlingJoinExpression() "LINEITEM_OK", "orderkey", "LINEITEM_LINENUMBER", "linenumber"))))))); } + + @Test + public void testPushDownToLhsOfSemiJoin() + { + assertPlan("SELECT quantity FROM (SELECT * FROM lineitem WHERE orderkey IN (SELECT orderkey FROM orders)) " + + "WHERE linenumber = 2", + anyTree( + semiJoin("LINE_ORDER_KEY", "ORDERS_ORDER_KEY", "SEMI_JOIN_RESULT", + anyTree( + filter("LINE_NUMBER = 2", + tableScan("lineitem", ImmutableMap.of( + "LINE_ORDER_KEY", "orderkey", + "LINE_NUMBER", "linenumber", + "LINE_QUANTITY", "quantity") + ) + ) + ), + anyTree(tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey"))) + ) + ) + ); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java index 63c6c181bbd0b..971dfc336b36b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java @@ -13,28 +13,25 @@ */ package com.facebook.presto.sql.planner; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; -import java.util.Map; import java.util.Optional; +import java.util.Set; import static org.testng.AssertJUnit.assertEquals; public class TestSortExpressionExtractor { - private static final Map BUILD_LAYOUT = ImmutableMap.of( - new Symbol("b1"), 1, - new Symbol("b2"), 2); + private static final Set BUILD_SYMBOLS = ImmutableSet.of(new Symbol("b1"), new Symbol("b2")); @Test public void testGetSortExpression() @@ -42,47 +39,60 @@ public void testGetSortExpression() assertGetSortExpression( new ComparisonExpression( ComparisonExpressionType.GREATER_THAN, - new FieldReference(11), - new FieldReference(1)), - 1); + new SymbolReference("p1"), + new SymbolReference("b1")), + "b1"); assertGetSortExpression( new ComparisonExpression( ComparisonExpressionType.LESS_THAN_OR_EQUAL, - new FieldReference(2), - new FieldReference(11)), - 2); + new SymbolReference("b2"), + new SymbolReference("p1")), + "b2"); assertGetSortExpression( new ComparisonExpression( ComparisonExpressionType.GREATER_THAN, - new FieldReference(2), - new FieldReference(11)), - 2); + new SymbolReference("b2"), + new SymbolReference("p1")), + "b2"); assertGetSortExpression( new ComparisonExpression( ComparisonExpressionType.GREATER_THAN, - new FieldReference(1), - new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, new FieldReference(2), new FieldReference(11)))); + new SymbolReference("b2"), + new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new SymbolReference("p1")))), + "b2"); assertGetSortExpression( new ComparisonExpression( ComparisonExpressionType.GREATER_THAN, - new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new FieldReference(1))), - new FieldReference(11))); + new SymbolReference("b2"), + new FunctionCall(QualifiedName.of("random"), ImmutableList.of(new SymbolReference("p1"))))); + + assertGetSortExpression( + new ComparisonExpression( + ComparisonExpressionType.GREATER_THAN, + new SymbolReference("b1"), + new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, new SymbolReference("b2"), new SymbolReference("p1")))); + + assertGetSortExpression( + new ComparisonExpression( + ComparisonExpressionType.GREATER_THAN, + new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new SymbolReference("b1"))), + new SymbolReference("p1"))); } private static void assertGetSortExpression(Expression expression) { - Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_LAYOUT, expression); + Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertEquals(Optional.empty(), actual); } - private static void assertGetSortExpression(Expression expression, int expectedChannel) + private static void assertGetSortExpression(Expression expression, String expectedSymbol) { - Optional expected = Optional.of(new SortExpression(expectedChannel)); - Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_LAYOUT, expression); + Optional expected = Optional.of(new SymbolReference(expectedSymbol)); + Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertEquals(expected, actual); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 69b73602bba32..f91866fd8610e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; @@ -37,6 +38,8 @@ import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.WindowFrame; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; @@ -181,23 +184,21 @@ public void testValidAggregation() throws Exception { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - Map functions = ImmutableMap.of( - aggregationSymbol, new Signature( - "sum", - FunctionKind.AGGREGATE, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false)); - Map aggregations = ImmutableMap.of(aggregationSymbol, new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference()))); PlanNode node = new AggregationNode( newId(), baseTableScan, - aggregations, - functions, - ImmutableMap.of(), + ImmutableMap.of(aggregationSymbol, new Aggregation( + new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())), + new Signature( + "sum", + FunctionKind.AGGREGATE, + ImmutableList.of(), + ImmutableList.of(), + DOUBLE.getTypeSignature(), + ImmutableList.of(DOUBLE.getTypeSignature()), + false), + Optional.empty())), ImmutableList.of(ImmutableList.of(columnA, columnB)), SINGLE, Optional.empty(), @@ -243,23 +244,21 @@ public void testInvalidAggregationFunctionCall() throws Exception { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - Map functions = ImmutableMap.of( - aggregationSymbol, new Signature( - "sum", - FunctionKind.AGGREGATE, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false)); - Map aggregations = ImmutableMap.of(aggregationSymbol, new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnA.toSymbolReference()))); // should be columnC PlanNode node = new AggregationNode( newId(), baseTableScan, - aggregations, - functions, - ImmutableMap.of(), + ImmutableMap.of(aggregationSymbol, new Aggregation( + new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnA.toSymbolReference())), + new Signature( + "sum", + FunctionKind.AGGREGATE, + ImmutableList.of(), + ImmutableList.of(), + DOUBLE.getTypeSignature(), + ImmutableList.of(DOUBLE.getTypeSignature()), + false), + Optional.empty())), ImmutableList.of(ImmutableList.of(columnA, columnB)), SINGLE, Optional.empty(), @@ -273,23 +272,21 @@ public void testInvalidAggregationFunctionSignature() throws Exception { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - Map functions = ImmutableMap.of( - aggregationSymbol, new Signature( - "sum", - FunctionKind.AGGREGATE, - ImmutableList.of(), - ImmutableList.of(), - BIGINT.getTypeSignature(), // should be DOUBLE - ImmutableList.of(DOUBLE.getTypeSignature()), - false)); - Map aggregations = ImmutableMap.of(aggregationSymbol, new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference()))); PlanNode node = new AggregationNode( newId(), baseTableScan, - aggregations, - functions, - ImmutableMap.of(), + ImmutableMap.of(aggregationSymbol, new Aggregation( + new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())), + new Signature( + "sum", + FunctionKind.AGGREGATE, + ImmutableList.of(), + ImmutableList.of(), + BIGINT.getTypeSignature(), // should be DOUBLE + ImmutableList.of(DOUBLE.getTypeSignature()), + false), + Optional.empty())), ImmutableList.of(ImmutableList.of(columnA, columnB)), SINGLE, Optional.empty(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java index 951f7b13ec185..7232a0701f4eb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java @@ -17,6 +17,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.tree.FunctionCall; @@ -47,8 +48,8 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada AggregationNode aggregationNode = (AggregationNode) node; FunctionCall expectedCall = callMaker.getExpectedValue(symbolAliases); - for (Map.Entry assignment : aggregationNode.getAggregations().entrySet()) { - if (expectedCall.equals(assignment.getValue())) { + for (Map.Entry assignment : aggregationNode.getAggregations().entrySet()) { + if (expectedCall.equals(assignment.getValue().getCall())) { checkState(!result.isPresent(), "Ambiguous function calls in %s", aggregationNode); result = Optional.of(assignment.getKey()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java index 6cb68800f475c..abc57531342b7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -55,7 +56,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); AggregationNode aggregationNode = (AggregationNode) node; @@ -71,7 +72,7 @@ public MatchResult detailMatches(PlanNode node, Session session, Metadata metada List aggregationsWithMask = aggregationNode.getAggregations() .entrySet() .stream() - .filter(entry -> entry.getValue().isDistinct()) + .filter(entry -> entry.getValue().getCall().isDistinct()) .map(entry -> entry.getKey()) .collect(Collectors.toList()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Alias.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java similarity index 88% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Alias.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java index 81f071ad88ffa..49cb8097ad655 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Alias.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -24,13 +25,13 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -public class Alias +public class AliasMatcher implements Matcher { private final Optional alias; private final RvalueMatcher matcher; - Alias(Optional alias, RvalueMatcher matcher) + AliasMatcher(Optional alias, RvalueMatcher matcher) { this.alias = requireNonNull(alias, "alias is null"); this.matcher = requireNonNull(matcher, "matcher is null"); @@ -51,7 +52,7 @@ public boolean shapeMatches(PlanNode node) * higher up the tree. */ @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { Optional symbol = matcher.getAssignedSymbol(node, session, metadata, symbolAliases); if (symbol.isPresent() && alias.isPresent()) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java new file mode 100644 index 0000000000000..40a9da90be0a7 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AliasPresent.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * Just check alias is present; return self mapping from symbolAliasesMap. + */ +class AliasPresent + implements RvalueMatcher +{ + private final String alias; + + AliasPresent(String alias) + { + this.alias = requireNonNull(alias, "alias can not be null"); + } + + @Override + public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + return symbolAliases.getOptional(alias) + .map(Symbol::from); + } + + @Override + public String toString() + { + return "has " + alias; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java new file mode 100644 index 0000000000000..440713f687636 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/AssignUniqueIdMatcher.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; + +public class AssignUniqueIdMatcher + implements RvalueMatcher +{ + @Override + public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + if (!(node instanceof AssignUniqueId)) { + return Optional.empty(); + } + + AssignUniqueId assignUniqueIdNode = (AssignUniqueId) node; + + return Optional.of(assignUniqueIdNode.getIdColumn()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index 07e3bd8294158..af575a145dc20 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -33,10 +33,12 @@ import java.util.List; import java.util.Map; +import java.util.function.Predicate; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static io.airlift.testing.Closeables.closeAllRuntimeException; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; import static org.testng.Assert.fail; public class BasePlanTest @@ -91,34 +93,42 @@ protected void assertPlan(String sql, PlanMatchPattern pattern) protected void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern) { - queryRunner.inTransaction(transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, stage); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); - return null; - }); + List optimizers = queryRunner.getPlanOptimizers(true); + + assertPlan(sql, stage, pattern, optimizers); } protected void assertPlanWithOptimizers(String sql, PlanMatchPattern pattern, List optimizers) + { + assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers); + } + + protected void assertPlanWithOptimizerFiltering(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern, Predicate optimizerPredicate) + { + List optimizers = queryRunner.getPlanOptimizers(true).stream() + .filter(optimizerPredicate) + .collect(toList()); + + assertPlan(sql, stage, pattern, optimizers); + } + + protected void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern, List optimizers) { queryRunner.inTransaction(transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, LogicalPlanner.Stage.OPTIMIZED); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, stage); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); return null; }); } protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMatchPattern pattern) { - LocalQueryRunner queryRunner = getQueryRunner(); List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), new PruneUnreferencedOutputs(), new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections()))); - queryRunner.inTransaction(transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, LogicalPlanner.Stage.OPTIMIZED); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); - return null; - }); + + assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers); } protected Plan plan(String sql) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java index 5cdaf0a9c6e75..fd9b9bbb66561 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BaseStrictSymbolsMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -34,6 +35,7 @@ public BaseStrictSymbolsMatcher(Function> getActual) this.getActual = requireNonNull(getActual, "getActual is null"); } + @Override public boolean shapeMatches(PlanNode node) { try { @@ -45,12 +47,10 @@ public boolean shapeMatches(PlanNode node) } } - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); - - Set expected = getExpectedSymbols(node, session, metadata, symbolAliases); - return new MatchResult(getActual.apply(node).equals(getExpectedSymbols(node, session, metadata, symbolAliases))); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java index c58899d2fb2aa..6c794a8dc8acf 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/CorrelationMatcher.java @@ -14,8 +14,11 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import java.util.List; @@ -39,29 +42,44 @@ public class CorrelationMatcher @Override public boolean shapeMatches(PlanNode node) { - return node instanceof ApplyNode; + return node instanceof ApplyNode || node instanceof LateralJoinNode; } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { - checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + checkState( + shapeMatches(node), + "Plan testing framework error: shapeMatches returned false in detailMatches in %s", + this.getClass().getName()); - ApplyNode applyNode = (ApplyNode) node; - - if (correlation.size() != applyNode.getCorrelation().size()) { + List actualCorrelation = getCorrelation(node); + if (this.correlation.size() != actualCorrelation.size()) { return NO_MATCH; } int i = 0; - for (String alias : correlation) { - if (!symbolAliases.get(alias).equals(applyNode.getCorrelation().get(i++).toSymbolReference())) { + for (String alias : this.correlation) { + if (!symbolAliases.get(alias).equals(actualCorrelation.get(i++).toSymbolReference())) { return NO_MATCH; } } return match(); } + private List getCorrelation(PlanNode node) + { + if (node instanceof ApplyNode) { + return ((ApplyNode) node).getCorrelation(); + } + else if (node instanceof LateralJoinNode) { + return ((LateralJoinNode) node).getCorrelation(); + } + else { + throw new IllegalStateException("Unexpected plan node: " + node); + } + } + @Override public String toString() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java index 04a5c350ad271..8e7e320e13a24 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/EquiJoinClauseProvider.java @@ -33,4 +33,13 @@ public JoinNode.EquiJoinClause getExpectedValue(SymbolAliases aliases) { return new JoinNode.EquiJoinClause(left.toSymbol(aliases), right.toSymbol(aliases)); } + + @Override + public String toString() + { + return "EquiJoinClauseProvider{" + + "left=" + left + + ", right=" + right + + '}'; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java index 34df8c1d334f3..9e5771c530a26 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExchangeMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -45,7 +46,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index a400ff8a332c8..00e48f578a4fd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -25,6 +25,8 @@ import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.IsNotNullPredicate; +import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.Node; @@ -105,6 +107,30 @@ protected Boolean visitCast(Cast actual, Expression expectedExpression) return process(actual.getExpression(), expected.getExpression()); } + @Override + protected Boolean visitIsNullPredicate(IsNullPredicate actual, Expression expectedExpression) + { + if (!(expectedExpression instanceof IsNullPredicate)) { + return false; + } + + IsNullPredicate expected = (IsNullPredicate) expectedExpression; + + return process(actual.getValue(), expected.getValue()); + } + + @Override + protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Expression expectedExpression) + { + if (!(expectedExpression instanceof IsNotNullPredicate)) { + return false; + } + + IsNotNullPredicate expected = (IsNotNullPredicate) expectedExpression; + + return process(actual.getValue(), expected.getValue()); + } + @Override protected Boolean visitInPredicate(InPredicate actual, Expression expectedExpression) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java index dae55c500b113..69877793f041b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/FilterMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -40,7 +41,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java index 0e1aa82a73660..ea4dd92b99b2a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/GroupIdMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.GroupIdNode; @@ -48,7 +49,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java index 4aca9a6328374..d73a005c8b22e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -56,7 +57,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); @@ -87,8 +88,8 @@ public MatchResult detailMatches(PlanNode node, Session session, Metadata metada Set actual = ImmutableSet.copyOf(joinNode.getCriteria()); Set expected = equiCriteria.stream() - .map(maker -> maker.getExpectedValue(symbolAliases)) - .collect(toImmutableSet()); + .map(maker -> maker.getExpectedValue(symbolAliases)) + .collect(toImmutableSet()); return new MatchResult(expected.equals(actual)); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java index 9e78a06212d48..6d9573bbe7f93 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/LimitMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -42,7 +43,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node)); return MatchResult.match(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java new file mode 100644 index 0000000000000..d79f38a45f98d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/MarkDistinctMatcher.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class MarkDistinctMatcher + implements Matcher +{ + private final PlanTestSymbol markerSymbol; + private final List distinctSymbols; + private final Optional hashSymbol; + + public MarkDistinctMatcher(PlanTestSymbol markerSymbol, List distinctSymbols, Optional hashSymbol) + { + this.markerSymbol = requireNonNull(markerSymbol, "markerSymbol is null"); + this.distinctSymbols = ImmutableList.copyOf(distinctSymbols); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof MarkDistinctNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + MarkDistinctNode markDistinctNode = (MarkDistinctNode) node; + + if (!markDistinctNode.getHashSymbol().equals(hashSymbol.map(alias -> alias.toSymbol(symbolAliases)))) { + return NO_MATCH; + } + + if (!ImmutableSet.copyOf(markDistinctNode.getDistinctSymbols()) + .equals(distinctSymbols.stream().map(alias -> alias.toSymbol(symbolAliases)).collect(toImmutableSet()))) { + return NO_MATCH; + } + + return match(markerSymbol.toString(), markDistinctNode.getMarkerSymbol().toSymbolReference()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("markerSymbol", markerSymbol) + .add("distinctSymbols", distinctSymbols) + .add("hashSymbol", hashSymbol) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java index c7c12900ce5e0..106e5a6ce9637 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/Matcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -62,10 +63,11 @@ public interface Matcher * node if shapeMatches didn't return true for the same node. * * @param node The node to apply the matching tests to + * @param planNodeCost The computed cost of plan node * @param session The session information for the query * @param metadata The metadata for the query * @param symbolAliases The SymbolAliases containing aliases from the nodes sources * @return a MatchResult with information about the success of the match */ - MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases); + MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java index 268327dc9f3a6..a74219a22c441 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/NotPlanNodeMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -22,7 +23,8 @@ import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; -final class NotPlanNodeMatcher implements Matcher +final class NotPlanNodeMatcher + implements Matcher { private final Class excludedNodeClass; @@ -38,7 +40,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); return match(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java index 0f00124bf72fc..f10fd79b83102 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OutputMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -44,7 +45,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { int i = 0; for (String alias : aliases) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java index eeaabbd5db068..386fb13b7a66e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java @@ -14,9 +14,13 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import static com.facebook.presto.sql.planner.iterative.Plans.resolveGroupReferences; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textLogicalPlan; import static java.lang.String.format; @@ -24,12 +28,23 @@ public final class PlanAssert { private PlanAssert() {} - public static void assertPlan(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern) + public static void assertPlan(Session session, Metadata metadata, CostCalculator costCalculator, Plan actual, PlanMatchPattern pattern) { - MatchResult matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata), pattern); + assertPlan(session, metadata, costCalculator, actual, Lookup.noLookup(), pattern); + } + + public static void assertPlan(Session session, Metadata metadata, CostCalculator costCalculator, Plan actual, Lookup lookup, PlanMatchPattern pattern) + { + MatchResult matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata, actual.getPlanNodeCosts(), lookup), pattern); if (!matches.isMatch()) { - String logicalPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata, session); - throw new AssertionError(format("Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n]", pattern, logicalPlan)); + String formattedPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata, costCalculator, session); + PlanNode resolvedPlan = resolveGroupReferences(actual.getRoot(), lookup); + String resolvedFormattedPlan = textLogicalPlan(resolvedPlan, actual.getTypes(), metadata, costCalculator, session); + throw new AssertionError(format( + "Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n] which resolves to [\n\n%s\n]", + pattern, + formattedPlan, + resolvedFormattedPlan)); } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanCostMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanCostMatcher.java new file mode 100644 index 0000000000000..06793f43f4839 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanCostMatcher.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.plan.PlanNode; + +import static java.util.Objects.requireNonNull; + +public class PlanCostMatcher + implements Matcher +{ + private final PlanNodeCost expectedCost; + + PlanCostMatcher(PlanNodeCost expectedCost) + { + this.expectedCost = requireNonNull(expectedCost, "expectedCost is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return true; + } + + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + return new MatchResult(expectedCost.equals(cost)); + } + + @Override + public String toString() + { + return "expectedCost(" + expectedCost + ")"; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index a7d1a46d0ea49..733e83e24475b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.predicate.Domain; @@ -22,13 +23,16 @@ import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.ExceptNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -45,12 +49,14 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.IntStream; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; @@ -158,6 +164,29 @@ public static PlanMatchPattern aggregation( return result; } + public static PlanMatchPattern markDistinct( + String markerSymbol, + List distinctSymbols, + PlanMatchPattern source) + { + return node(MarkDistinctNode.class, source).with(new MarkDistinctMatcher( + new SymbolAlias(markerSymbol), + toSymbolAliases(distinctSymbols), + Optional.empty())); + } + + public static PlanMatchPattern markDistinct( + String markerSymbol, + List distinctSymbols, + String hashSymbol, + PlanMatchPattern source) + { + return node(MarkDistinctNode.class, source).with(new MarkDistinctMatcher( + new SymbolAlias(markerSymbol), + toSymbolAliases(distinctSymbols), + Optional.of(new SymbolAlias(hashSymbol)))); + } + public static PlanMatchPattern window( ExpectedValueProvider specification, List> windowFunctions, @@ -261,6 +290,12 @@ public static PlanMatchPattern union(PlanMatchPattern... sources) return node(UnionNode.class, sources); } + public static PlanMatchPattern assignUniqueId(String uniqueSymbolAlias, PlanMatchPattern source) + { + return node(AssignUniqueId.class, source) + .withAlias(uniqueSymbolAlias, new AssignUniqueIdMatcher()); + } + public static PlanMatchPattern intersect(PlanMatchPattern... sources) { return node(IntersectNode.class, sources); @@ -296,17 +331,52 @@ public static PlanMatchPattern apply(List correlationSymbolAliases, Map< return result; } + public static PlanMatchPattern lateral(List correlationSymbolAliases, PlanMatchPattern inputPattern, PlanMatchPattern subqueryPattern) + { + return node(LateralJoinNode.class, inputPattern, subqueryPattern) + .with(new CorrelationMatcher(correlationSymbolAliases)); + } + public static PlanMatchPattern groupingSet(List> groups, String groupIdAlias, PlanMatchPattern source) { return node(GroupIdNode.class, source).with(new GroupIdMatcher(groups, ImmutableMap.of(), groupIdAlias)); } - public static PlanMatchPattern values(Map values) + private static PlanMatchPattern values( + Map aliasToIndex, + Optional expectedOutputSymbolCount, + Optional>> expectedRows + ) { - PlanMatchPattern result = node(ValuesNode.class); - values.entrySet().forEach( - alias -> result.withAlias(alias.getKey(), new ValuesMatcher(alias.getValue()))); - return result; + return node(ValuesNode.class).with(new ValuesMatcher(aliasToIndex, expectedOutputSymbolCount, expectedRows)); + } + + private static PlanMatchPattern values(List aliases, Optional>> expectedRows) + { + return values( + Maps.uniqueIndex(IntStream.range(0, aliases.size()).boxed().iterator(), aliases::get), + Optional.of(aliases.size()), + expectedRows); + } + + public static PlanMatchPattern values(Map aliasToIndex) + { + return values(aliasToIndex, Optional.empty(), Optional.empty()); + } + + public static PlanMatchPattern values(String ... aliases) + { + return values(ImmutableList.copyOf(aliases)); + } + + public static PlanMatchPattern values(List aliases, List> expectedRows) + { + return values(aliases, Optional.of(expectedRows)); + } + + public static PlanMatchPattern values(List aliases) + { + return values(aliases, Optional.empty()); } public static PlanMatchPattern limit(long limit, PlanMatchPattern source) @@ -339,12 +409,12 @@ List shapeMatches(PlanNode node) return states.build(); } - MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) { SymbolAliases.Builder newAliases = SymbolAliases.builder(); for (Matcher matcher : matchers) { - MatchResult matchResult = matcher.detailMatches(node, session, metadata, symbolAliases); + MatchResult matchResult = matcher.detailMatches(node, planNodeCost, session, metadata, symbolAliases); if (!matchResult.isMatch()) { return NO_MATCH; } @@ -360,6 +430,11 @@ public PlanMatchPattern with(Matcher matcher) return this; } + public PlanMatchPattern withAlias(String alias) + { + return withAlias(Optional.of(alias), new AliasPresent(alias)); + } + public PlanMatchPattern withAlias(String alias, RvalueMatcher matcher) { return withAlias(Optional.of(alias), matcher); @@ -367,7 +442,7 @@ public PlanMatchPattern withAlias(String alias, RvalueMatcher matcher) public PlanMatchPattern withAlias(Optional alias, RvalueMatcher matcher) { - matchers.add(new Alias(alias, matcher)); + matchers.add(new AliasMatcher(alias, matcher)); return this; } @@ -382,6 +457,11 @@ public PlanMatchPattern withNumberOfOutputColumns(int numberOfSymbols) * in the outputs. This is the case for symbols that are produced by a direct or indirect * source of the node you're applying this to. */ + public PlanMatchPattern withExactOutputs(String... expectedAliases) + { + return withExactOutputs(ImmutableList.copyOf(expectedAliases)); + } + public PlanMatchPattern withExactOutputs(List expectedAliases) { matchers.add(new StrictSymbolsMatcher(actualOutputs(), expectedAliases)); @@ -394,18 +474,34 @@ public PlanMatchPattern withExactOutputs(List expectedAliases) * the alias is *not* known when the Matcher is run, and so you need to match by what * is being assigned to it. */ + public PlanMatchPattern withExactAssignedOutputs(RvalueMatcher... expectedAliases) + { + return withExactAssignedOutputs(ImmutableList.copyOf(expectedAliases)); + } + public PlanMatchPattern withExactAssignedOutputs(Collection expectedAliases) { matchers.add(new StrictAssignedSymbolsMatcher(actualOutputs(), expectedAliases)); return this; } + public PlanMatchPattern withExactAssignments(RvalueMatcher... expectedAliases) + { + return withExactAssignments(ImmutableList.copyOf(expectedAliases)); + } + public PlanMatchPattern withExactAssignments(Collection expectedAliases) { matchers.add(new StrictAssignedSymbolsMatcher(actualAssignments(), expectedAliases)); return this; } + public PlanMatchPattern withCost(PlanNodeCost cost) + { + matchers.add(new PlanCostMatcher(cost)); + return this; + } + public static RvalueMatcher columnReference(String tableName, String columnName) { return new ColumnReference(tableName, columnName); @@ -416,6 +512,11 @@ public static ExpressionMatcher expression(String expression) return new ExpressionMatcher(expression); } + public PlanMatchPattern withOutputs(String... aliases) + { + return withOutputs(ImmutableList.copyOf(aliases)); + } + public PlanMatchPattern withOutputs(List aliases) { matchers.add(new OutputMatcher(aliases)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java index 3bdc4a02ade76..42f5dc64aa6dd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java @@ -14,15 +14,20 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; import java.util.List; +import java.util.Map; import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.MatchResult.match; @@ -31,15 +36,19 @@ import static java.util.Objects.requireNonNull; final class PlanMatchingVisitor - extends PlanVisitor + extends PlanVisitor { private final Metadata metadata; private final Session session; + private final Map planCost; + private final Lookup lookup; - PlanMatchingVisitor(Session session, Metadata metadata) + PlanMatchingVisitor(Session session, Metadata metadata, Map planCost, Lookup lookup) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); + this.planCost = requireNonNull(planCost, "planCost is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); } @Override @@ -78,6 +87,16 @@ public MatchResult visitProject(ProjectNode node, PlanMatchPattern pattern) return match(result.getAliases().replaceAssignments(node.getAssignments())); } + @Override + public MatchResult visitGroupReference(GroupReference node, PlanMatchPattern pattern) + { + MatchResult match = lookup.resolve(node).accept(this, pattern); + if (match.isMatch()) { + return match; + } + return visitPlan(node, pattern); + } + @Override protected MatchResult visitPlan(PlanNode node, PlanMatchPattern pattern) { @@ -104,7 +123,7 @@ protected MatchResult visitPlan(PlanNode node, PlanMatchPattern pattern) // Try upMatching this node with the the aliases gathered from the source nodes. SymbolAliases allSourceAliases = sourcesMatch.getAliases(); - MatchResult matchResult = pattern.detailMatches(node, session, metadata, allSourceAliases); + MatchResult matchResult = pattern.detailMatches(node, planCost.get(node.getId()), session, metadata, allSourceAliases); if (matchResult.isMatch()) { checkState(result == NO_MATCH, format("Ambiguous match on node %s", node)); result = match(allSourceAliases.withNewAliases(matchResult.getAliases())); @@ -129,7 +148,7 @@ private MatchResult matchLeaf(PlanNode node, PlanMatchPattern pattern, List new IllegalStateException(format("missing expression for alias %s", alias))); + } + + public Optional getOptional(String alias) + { + alias = toKey(alias); + SymbolReference result = map.get(alias); + return Optional.ofNullable(result); } private static String toKey(String alias) @@ -206,7 +211,6 @@ public Builder put(String alias, SymbolReference symbolReference) } checkState(!bindings.containsKey(alias), "Alias '%s' already bound to expression '%s'. Tried to rebind to '%s'", alias, bindings.get(alias), symbolReference); - checkState(!bindings.values().contains(symbolReference), "Expression '%s' is already bound in %s. Tried to rebind as '%s'.", symbolReference, bindings, alias); bindings.put(alias, symbolReference); return this; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java index 925a64e8e3793..bba682f3d2d0b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolCardinalityMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -36,7 +37,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) { return new MatchResult(node.getOutputSymbols().size() == numberOfSymbols); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java index e6f1541657649..65b39ad9cc47d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/TableScanMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.TableMetadata; import com.facebook.presto.spi.ColumnHandle; @@ -54,7 +55,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); @@ -63,7 +64,7 @@ public MatchResult detailMatches(PlanNode node, Session session, Metadata metada String actualTableName = tableMetadata.getTable().getTableName(); return new MatchResult( expectedTableName.equalsIgnoreCase(actualTableName) && - domainMatches(tableScanNode, session, metadata)); + domainMatches(tableScanNode, session, metadata)); } private boolean domainMatches(TableScanNode tableScanNode, Session session, Metadata metadata) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java index fa4768a49cbae..f194f1587fe05 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ValuesMatcher.java @@ -14,32 +14,70 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import java.util.List; +import java.util.Map; import java.util.Optional; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + public class ValuesMatcher - implements RvalueMatcher + implements Matcher { - int outputIndex; + private final Map outputSymbolAliases; + private final Optional expectedOutputSymbolCount; + private final Optional>> expectedRows; - public ValuesMatcher(int outputIndex) + public ValuesMatcher( + Map outputSymbolAliases, + Optional expectedOutputSymbolCount, + Optional>> expectedRows) { - this.outputIndex = outputIndex; + this.outputSymbolAliases = ImmutableMap.copyOf(outputSymbolAliases); + this.expectedOutputSymbolCount = requireNonNull(expectedOutputSymbolCount, "expectedOutputSymbolCount is null"); + this.expectedRows = requireNonNull(expectedRows, "expectedRows is null"); } @Override - public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public boolean shapeMatches(PlanNode node) { - if (!(node instanceof ValuesNode)) { - return Optional.empty(); - } + return (node instanceof ValuesNode) && + expectedOutputSymbolCount.map(Integer.valueOf(node.getOutputSymbols().size())::equals).orElse(true); + } + @Override + public MatchResult detailMatches(PlanNode node, PlanNodeCost planNodeCost, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); ValuesNode valuesNode = (ValuesNode) node; - return Optional.of(valuesNode.getOutputSymbols().get(outputIndex)); + if (!expectedRows.map(rows -> rows.equals(valuesNode.getRows())).orElse(true)) { + return NO_MATCH; + } + + return match(SymbolAliases.builder() + .putAll(Maps.transformValues(outputSymbolAliases, index -> valuesNode.getOutputSymbols().get(index).toSymbolReference())) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("outputSymbolAliases", outputSymbolAliases) + .add("expectedOutputSymbolCount", expectedOutputSymbolCount) + .add("expectedRows", expectedRows) + .toString(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index b3204da6e1cb0..6d291969351ad 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -39,7 +40,7 @@ public boolean shapeMatches(PlanNode node) } @Override - public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) + public MatchResult detailMatches(PlanNode node, PlanNodeCost cost, Session session, Metadata metadata, SymbolAliases symbolAliases) { checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java new file mode 100644 index 0000000000000..6097f46a177f8 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleStore.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.DummyMetadata; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.BooleanLiteral; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.stream.Collectors.toList; +import static org.testng.Assert.assertEquals; + +public class TestRuleStore +{ + private final PlanBuilder planBuilder = new PlanBuilder(new PlanNodeIdAllocator(), new DummyMetadata()); + + @Test + public void test() + { + Rule projectRule1 = new NoOpRule(Pattern.node(ProjectNode.class)); + Rule projectRule2 = new NoOpRule(Pattern.node(ProjectNode.class)); + Rule filterRule = new NoOpRule(Pattern.node(FilterNode.class)); + Rule anyRule = new NoOpRule(Pattern.any()); + + RuleStore ruleStore = RuleStore.builder() + .register(projectRule1) + .register(projectRule2) + .register(filterRule) + .register(anyRule) + .build(); + + ProjectNode projectNode = planBuilder.project(Assignments.of(), planBuilder.values()); + FilterNode filterNode = planBuilder.filter(BooleanLiteral.TRUE_LITERAL, planBuilder.values()); + ValuesNode valuesNode = planBuilder.values(); + + assertEquals( + ruleStore.getCandidates(projectNode).collect(toList()), + ImmutableList.of(projectRule1, projectRule2, anyRule)); + assertEquals( + ruleStore.getCandidates(filterNode).collect(toList()), + ImmutableList.of(filterRule, anyRule)); + assertEquals( + ruleStore.getCandidates(valuesNode).collect(toList()), + ImmutableList.of(anyRule)); + } + + private static class NoOpRule + implements Rule + { + private final Pattern pattern; + + private NoOpRule(Pattern pattern) + { + this.pattern = pattern; + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + { + return Optional.empty(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("pattern", pattern) + .toString(); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index c466e2bcb4e25..3905f8044fa6c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -15,15 +15,13 @@ import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Optional; @@ -44,45 +42,30 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestAddIntermediateAggregations + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testBasic() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .matches( aggregation( @@ -123,21 +106,21 @@ public void testNoInputCount() ExpectedValueProvider rawInputCount = PlanMatchPattern.functionCall("count", false, ImmutableList.of()); ExpectedValueProvider partialInputCount = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(*)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(*)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .matches( aggregation( @@ -176,13 +159,13 @@ public void testMultipleExchanges() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, @@ -190,9 +173,9 @@ public void testMultipleExchanges() ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT))))))); + p.values(p.symbol("a"))))))); })) .matches( aggregation( @@ -230,21 +213,21 @@ public void testMultipleExchanges() @Test public void testSessionDisable() { - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "false") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .doesNotFire(); } @@ -254,21 +237,21 @@ public void testNoLocalParallel() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "1") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .matches( aggregation( @@ -297,21 +280,21 @@ public void testNoLocalParallel() @Test public void testWithGroups() { - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { - af.groupingSets(ImmutableList.of(ImmutableList.of(p.symbol("c", BIGINT)))) + af.addGroupingSet(p.symbol("c")) .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.groupingSets(ImmutableList.of(ImmutableList.of(p.symbol("b", BIGINT)))) + p.aggregation(ap -> ap.addGroupingSet(p.symbol("b")) .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT)))))); + p.values(p.symbol("a")))))); })) .doesNotFire(); } @@ -321,23 +304,23 @@ public void testInterimProject() { ExpectedValueProvider aggregationPattern = PlanMatchPattern.functionCall("count", false, ImmutableList.of(anySymbol())); - tester.assertThat(new AddIntermediateAggregations()) + tester().assertThat(new AddIntermediateAggregations()) .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true") .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c", BIGINT), expression("count(b)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.project( - Assignments.identity(p.symbol("b", BIGINT)), + Assignments.identity(p.symbol("b")), p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) - .addAggregation(p.symbol("b", BIGINT), expression("count(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( - p.values(p.symbol("a", BIGINT))))))); + p.values(p.symbol("a"))))))); })) .matches( aggregation( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java similarity index 60% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateCrossJoins.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 29357306d8717..e46826bfb567e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -11,13 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.optimizations; +package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; @@ -29,8 +33,15 @@ import java.util.Arrays; import java.util.Optional; +import java.util.function.Function; -import static com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins.isOriginalOrder; +import static com.facebook.presto.SystemSessionProperties.REORDER_JOINS; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; +import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -41,8 +52,54 @@ @Test(singleThreaded = true) public class TestEliminateCrossJoins + extends BaseRuleTest { - PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + + @Test + public void testEliminateCrossJoin() + { + tester().assertThat(new EliminateCrossJoins()) + .setSystemProperty(REORDER_JOINS, "true") + .on(crossJoinAndJoin(INNER)) + .matches( + join(INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("cySymbol"), new Symbol("bySymbol"))), + join(INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("axSymbol"), new Symbol("cxSymbol"))), + any(), + any() + ), + any() + ) + ); + } + + @Test + public void testRetainOutgoingGroupReferences() + { + tester().assertThat(new EliminateCrossJoins()) + .setSystemProperty(REORDER_JOINS, "true") + .on(crossJoinAndJoin(INNER)) + .matches( + node(JoinNode.class, + node(JoinNode.class, + node(GroupReference.class), + node(GroupReference.class) + ), + node(GroupReference.class) + ) + ); + } + + @Test + public void testDoNotReorderOuterJoin() + { + tester().assertThat(new EliminateCrossJoins()) + .setSystemProperty(REORDER_JOINS, "true") + .on(crossJoinAndJoin(JoinNode.Type.LEFT)) + .doesNotFire(); + } @Test public void testIsOriginalOrder() @@ -55,8 +112,8 @@ public void testIsOriginalOrder() public void testJoinOrder() { PlanNode plan = - join( - join( + joinNode( + joinNode( values(symbol("a")), values(symbol("b"))), values(symbol("c")), @@ -66,7 +123,7 @@ public void testJoinOrder() JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); assertEquals( - EliminateCrossJoins.getJoinOrder(joinGraph), + getJoinOrder(joinGraph), ImmutableList.of(0, 2, 1)); } @@ -74,8 +131,8 @@ public void testJoinOrder() public void testJoinOrderWithRealCrossJoin() { PlanNode leftPlan = - join( - join( + joinNode( + joinNode( values(symbol("a")), values(symbol("b"))), values(symbol("c")), @@ -83,20 +140,20 @@ public void testJoinOrderWithRealCrossJoin() symbol("c"), symbol("b")); PlanNode rightPlan = - join( - join( + joinNode( + joinNode( values(symbol("x")), values(symbol("y"))), values(symbol("z")), symbol("x"), symbol("z"), symbol("z"), symbol("y")); - PlanNode plan = join(leftPlan, rightPlan); + PlanNode plan = joinNode(leftPlan, rightPlan); JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); assertEquals( - EliminateCrossJoins.getJoinOrder(joinGraph), + getJoinOrder(joinGraph), ImmutableList.of(0, 2, 1, 3, 5, 4)); } @@ -104,8 +161,8 @@ public void testJoinOrderWithRealCrossJoin() public void testJoinOrderWithMultipleEdgesBetweenNodes() { PlanNode plan = - join( - join( + joinNode( + joinNode( values(symbol("a")), values(symbol("b1"), symbol("b2"))), values(symbol("c1"), symbol("c2")), @@ -116,7 +173,7 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); assertEquals( - EliminateCrossJoins.getJoinOrder(joinGraph), + getJoinOrder(joinGraph), ImmutableList.of(0, 2, 1)); } @@ -124,8 +181,8 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() public void testDonNotChangeOrderWithoutCrossJoin() { PlanNode plan = - join( - join( + joinNode( + joinNode( values(symbol("a")), values(symbol("b")), symbol("a"), symbol("b")), @@ -135,7 +192,7 @@ public void testDonNotChangeOrderWithoutCrossJoin() JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); assertEquals( - EliminateCrossJoins.getJoinOrder(joinGraph), + getJoinOrder(joinGraph), ImmutableList.of(0, 1, 2)); } @@ -143,8 +200,8 @@ public void testDonNotChangeOrderWithoutCrossJoin() public void testDoNotReorderCrossJoins() { PlanNode plan = - join( - join( + joinNode( + joinNode( values(symbol("a")), values(symbol("b"))), values(symbol("c")), @@ -153,7 +210,7 @@ public void testDoNotReorderCrossJoins() JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); assertEquals( - EliminateCrossJoins.getJoinOrder(joinGraph), + getJoinOrder(joinGraph), ImmutableList.of(0, 1, 2)); } @@ -161,9 +218,9 @@ public void testDoNotReorderCrossJoins() public void testGiveUpOnNonIdentityProjections() { PlanNode plan = - join( - project( - join( + joinNode( + projectNode( + joinNode( values(symbol("a1")), values(symbol("b"))), symbol("a2"), @@ -175,7 +232,28 @@ public void testGiveUpOnNonIdentityProjections() assertEquals(JoinGraph.buildFrom(plan).size(), 2); } - private PlanNode project(PlanNode source, String symbol, Expression expression) + private Function crossJoinAndJoin(JoinNode.Type secondJoinType) + { + return p -> { + Symbol axSymbol = p.symbol("axSymbol"); + Symbol bySymbol = p.symbol("bySymbol"); + Symbol cxSymbol = p.symbol("cxSymbol"); + Symbol cySymbol = p.symbol("cySymbol"); + + // (a inner join b) inner join c on c.x = a.x and c.y = b.y + return p.join(INNER, + p.join(secondJoinType, + p.values(axSymbol), + p.values(bySymbol) + ), + p.values(cxSymbol, cySymbol), + new EquiJoinClause(cxSymbol, axSymbol), + new EquiJoinClause(cySymbol, bySymbol) + ); + }; + } + + private PlanNode projectNode(PlanNode source, String symbol, Expression expression) { return new ProjectNode( idAllocator.getNextId(), @@ -188,7 +266,7 @@ private String symbol(String name) return name; } - private JoinNode join(PlanNode left, PlanNode right, String... symbols) + private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) { checkArgument(symbols.length % 2 == 0); ImmutableList.Builder criteria = ImmutableList.builder(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestEvaluateZeroLimit.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java similarity index 69% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestEvaluateZeroLimit.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java index 133a783005118..0d43ee7179891 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestEvaluateZeroLimit.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroLimit.java @@ -11,47 +11,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative.rule.test; +package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroLimit; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestEvaluateZeroLimit + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new EvaluateZeroLimit()) + tester().assertThat(new EvaluateZeroLimit()) .on(p -> p.limit( 1, - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -59,14 +41,14 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new EvaluateZeroLimit()) + tester().assertThat(new EvaluateZeroLimit()) .on(p -> p.limit( 0, p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( expressions("1", "10"), expressions("2", "11")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestEvaluateZeroSample.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java similarity index 71% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestEvaluateZeroSample.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java index a00a145270800..7a27936ee4468 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestEvaluateZeroSample.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEvaluateZeroSample.java @@ -11,49 +11,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative.rule.test; +package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroSample; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.SampleNode.Type; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestEvaluateZeroSample + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new EvaluateZeroSample()) + tester().assertThat(new EvaluateZeroSample()) .on(p -> p.sample( 0.15, Type.BERNOULLI, - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -61,7 +43,7 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new EvaluateZeroSample()) + tester().assertThat(new EvaluateZeroSample()) .on(p -> p.sample( 0, @@ -69,7 +51,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( expressions("1", "10"), expressions("2", "11")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java index c394c450c3e6e..1f639a2a4e41a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java @@ -15,58 +15,40 @@ import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestInlineProjections + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void test() { - tester.assertThat(new InlineProjections()) + tester().assertThat(new InlineProjections()) .on(p -> p.project( Assignments.builder() - .put(p.symbol("identity", BIGINT), expression("symbol")) // identity - .put(p.symbol("multi_complex_1", BIGINT), expression("complex + 1")) // complex expression referenced multiple times - .put(p.symbol("multi_complex_2", BIGINT), expression("complex + 2")) // complex expression referenced multiple times - .put(p.symbol("multi_literal_1", BIGINT), expression("literal + 1")) // literal referenced multiple times - .put(p.symbol("multi_literal_2", BIGINT), expression("literal + 2")) // literal referenced multiple times - .put(p.symbol("single_complex", BIGINT), expression("complex_2 + 2")) // complex expression reference only once - .put(p.symbol("try", BIGINT), expression("try(complex / literal)")) + .put(p.symbol("identity"), expression("symbol")) // identity + .put(p.symbol("multi_complex_1"), expression("complex + 1")) // complex expression referenced multiple times + .put(p.symbol("multi_complex_2"), expression("complex + 2")) // complex expression referenced multiple times + .put(p.symbol("multi_literal_1"), expression("literal + 1")) // literal referenced multiple times + .put(p.symbol("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times + .put(p.symbol("single_complex"), expression("complex_2 + 2")) // complex expression reference only once + .put(p.symbol("try"), expression("try(complex / literal)")) .build(), p.project(Assignments.builder() - .put(p.symbol("symbol", BIGINT), expression("x")) - .put(p.symbol("complex", BIGINT), expression("x * 2")) - .put(p.symbol("literal", BIGINT), expression("1")) - .put(p.symbol("complex_2", BIGINT), expression("x - 1")) + .put(p.symbol("symbol"), expression("x")) + .put(p.symbol("complex"), expression("x * 2")) + .put(p.symbol("literal"), expression("1")) + .put(p.symbol("complex_2"), expression("x - 1")) .build(), - p.values(p.symbol("x", BIGINT))))) + p.values(p.symbol("x"))))) .matches( project( ImmutableMap.builder() @@ -89,13 +71,13 @@ public void test() public void testIdentityProjections() throws Exception { - tester.assertThat(new InlineProjections()) + tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.of(p.symbol("output", BIGINT), expression("value")), + Assignments.of(p.symbol("output"), expression("value")), p.project( - Assignments.identity(p.symbol("value", BIGINT)), - p.values(p.symbol("value", BIGINT))))) + Assignments.identity(p.symbol("value")), + p.values(p.symbol("value"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java new file mode 100644 index 0000000000000..7f0bd47bb9e55 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -0,0 +1,209 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.Signature; +import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.sql.tree.Window; +import com.facebook.presto.sql.tree.WindowFrame; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.tree.FrameBound.Type.CURRENT_ROW; +import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; + +public class TestMergeAdjacentWindows + extends BaseRuleTest +{ + private static final WindowNode.Frame frame = new WindowNode.Frame(WindowFrame.Type.RANGE, UNBOUNDED_PRECEDING, + Optional.empty(), CURRENT_ROW, Optional.empty()); + private static final Signature signature = new Signature( + "avg", + FunctionKind.WINDOW, + ImmutableList.of(), + ImmutableList.of(), + DOUBLE.getTypeSignature(), + ImmutableList.of(DOUBLE.getTypeSignature()), + false); + + @Test + public void testPlanWithoutWindowNode() + throws Exception + { + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> p.values(p.symbol("a"))) + .doesNotFire(); + } + + @Test + public void testPlanWithSingleWindowNode() + throws Exception + { + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", "a")), + p.values(p.symbol("a")))) + .doesNotFire(); + } + + @Test + public void testDistinctAdjacentWindowSpecifications() + { + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", "a")), + p.window( + newWindowNodeSpecification(p, "b"), + ImmutableMap.of(p.symbol("sum_1"), newWindowNodeFunction("sum", "b")), + p.values(p.symbol("b")) + ) + )) + .doesNotFire(); + } + + @Test + public void testNonWindowIntermediateNode() + { + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("lag_1"), newWindowNodeFunction("lag", "a", "ONE")), + p.project( + Assignments.copyOf(ImmutableMap.of(p.symbol("ONE"), p.expression("CAST(1 AS bigint)"))), + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", "a")), + p.values(p.symbol("a")) + ) + ) + )) + .doesNotFire(); + } + + @Test + public void testDependentAdjacentWindowsIdenticalSpecifications() + throws Exception + { + Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); + + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", windowA, "avg_2")), + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", windowA, "a")), + p.values(p.symbol("a")) + ) + )) + .doesNotFire(); + } + + @Test + public void testDependentAdjacentWindowsDistinctSpecifications() + throws Exception + { + Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); + + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", windowA, "avg_2")), + p.window( + newWindowNodeSpecification(p, "b"), + ImmutableMap.of(p.symbol("avg_2"), newWindowNodeFunction("avg", windowA, "a")), + p.values(p.symbol("a"), p.symbol("b")) + ) + )) + .doesNotFire(); + } + + @Test + public void testIdenticalAdjacentWindowSpecifications() + throws Exception + { + String columnAAlias = "ALIAS_A"; + + ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); + + Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); + + tester().assertThat(new MergeAdjacentWindows()) + .on(p -> + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("avg_1"), newWindowNodeFunction("avg", windowA, "a")), + p.window( + newWindowNodeSpecification(p, "a"), + ImmutableMap.of(p.symbol("sum_1"), newWindowNodeFunction("sum", windowA, "a")), + p.values(p.symbol("a")) + ) + )) + .matches(window( + specificationA, + ImmutableList.of( + functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias)), + functionCall("sum", Optional.empty(), ImmutableList.of(columnAAlias))), + values(ImmutableMap.of(columnAAlias, 0)))); + } + + private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName) + { + return new WindowNode.Specification(ImmutableList.of(planBuilder.symbol(symbolName, BIGINT)), ImmutableList.of(), ImmutableMap.of()); + } + + private WindowNode.Function newWindowNodeFunction(String functionName, String... symbols) + { + return new WindowNode.Function( + new FunctionCall( + QualifiedName.of(functionName), + Arrays.stream(symbols).map(symbol -> new SymbolReference(symbol)).collect(Collectors.toList())), + signature, + frame); + } + + private WindowNode.Function newWindowNodeFunction(String functionName, Optional window, String symbolName) + { + return new WindowNode.Function( + new FunctionCall(QualifiedName.of(functionName), window, false, ImmutableList.of(new SymbolReference(symbolName))), + signature, + frame); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java index 6f9ff8414fb41..9ce71f2c884fa 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeFilters.java @@ -13,43 +13,25 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestMergeFilters + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void test() { - tester.assertThat(new MergeFilters()) + tester().assertThat(new MergeFilters()) .on(p -> p.filter(expression("b > 44"), p.filter(expression("a < 42"), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) + p.values(p.symbol("a"), p.symbol("b"))))) .matches(filter("(a < 42) AND (b > 44)", values(ImmutableMap.of("a", 0, "b", 1)))); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java new file mode 100644 index 0000000000000..32816c19b1135 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneCrossJoinColumns + extends BaseRuleTest +{ + @Test + public void testLeftInputNotReferenced() + { + tester().assertThat(new PruneCrossJoinColumns()) + .on(p -> buildProjectedCrossJoin(p, symbol -> symbol.getName().equals("rightValue"))) + .matches( + strictProject( + ImmutableMap.of("rightValue", PlanMatchPattern.expression("rightValue")), + join( + JoinNode.Type.INNER, + ImmutableList.of(), + Optional.empty(), + strictProject( + ImmutableMap.of(), + values(ImmutableList.of("leftValue"))), + values(ImmutableList.of("rightValue"))) + .withExactOutputs("rightValue"))); + } + + @Test + public void testRightInputNotReferenced() + { + tester().assertThat(new PruneCrossJoinColumns()) + .on(p -> buildProjectedCrossJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .matches( + strictProject( + ImmutableMap.of("leftValue", PlanMatchPattern.expression("leftValue")), + join( + JoinNode.Type.INNER, + ImmutableList.of(), + Optional.empty(), + values(ImmutableList.of("leftValue")), + strictProject( + ImmutableMap.of(), + values(ImmutableList.of("rightValue")))) + .withExactOutputs("leftValue"))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneCrossJoinColumns()) + .on(p -> buildProjectedCrossJoin(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + private static PlanNode buildProjectedCrossJoin(PlanBuilder p, Predicate projectionFilter) + { + Symbol leftValue = p.symbol("leftValue"); + Symbol rightValue = p.symbol("rightValue"); + List outputs = ImmutableList.of(leftValue, rightValue); + return p.project( + Assignments.identity( + outputs.stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.join( + JoinNode.Type.INNER, + p.values(leftValue), + p.values(rightValue), + ImmutableList.of(), + outputs, + Optional.empty(), + Optional.empty(), + Optional.empty())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java new file mode 100644 index 0000000000000..935ee820db41a --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinChildrenColumns.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneJoinChildrenColumns + extends BaseRuleTest +{ + @Test + public void testNotAllInputsRereferenced() + { + tester().assertThat(new PruneJoinChildrenColumns()) + .on(p -> buildJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .matches( + join( + JoinNode.Type.INNER, + ImmutableList.of(equiJoinClause("leftKey", "rightKey")), + Optional.of("leftValue > 5"), + values("leftKey", "leftKeyHash", "leftValue"), + strictProject( + ImmutableMap.of( + "rightKey", PlanMatchPattern.expression("rightKey"), + "rightKeyHash", PlanMatchPattern.expression("rightKeyHash")), + values("rightKey", "rightKeyHash", "rightValue")))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneJoinChildrenColumns()) + .on(p -> buildJoin(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + @Test + public void testCrossJoinDoesNotFire() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> { + Symbol leftValue = p.symbol("leftValue"); + Symbol rightValue = p.symbol("rightValue"); + return p.join( + JoinNode.Type.INNER, + p.values(leftValue), + p.values(rightValue), + ImmutableList.of(), + ImmutableList.of(leftValue, rightValue), + Optional.empty(), + Optional.empty(), + Optional.empty()); + }) + .doesNotFire(); + } + + private static PlanNode buildJoin(PlanBuilder p, Predicate joinOutputFilter) + { + Symbol leftKey = p.symbol("leftKey"); + Symbol leftKeyHash = p.symbol("leftKeyHash"); + Symbol leftValue = p.symbol("leftValue"); + Symbol rightKey = p.symbol("rightKey"); + Symbol rightKeyHash = p.symbol("rightKeyHash"); + Symbol rightValue = p.symbol("rightValue"); + List outputs = ImmutableList.of(leftValue, rightValue); + return p.join( + JoinNode.Type.INNER, + p.values(leftKey, leftKeyHash, leftValue), + p.values(rightKey, rightKeyHash, rightValue), + ImmutableList.of(new JoinNode.EquiJoinClause(leftKey, rightKey)), + outputs.stream() + .filter(joinOutputFilter) + .collect(toImmutableList()), + Optional.of(expression("leftValue > 5")), + Optional.of(leftKeyHash), + Optional.of(rightKeyHash)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java new file mode 100644 index 0000000000000..7bfcc8337221d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneJoinColumns.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneJoinColumns + extends BaseRuleTest +{ + @Test + public void testNotAllOutputsReferenced() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> buildProjectedJoin(p, symbol -> symbol.getName().equals("rightValue"))) + .matches( + strictProject( + ImmutableMap.of("rightValue", PlanMatchPattern.expression("rightValue")), + join( + JoinNode.Type.INNER, + ImmutableList.of(equiJoinClause("leftKey", "rightKey")), + Optional.empty(), + values(ImmutableList.of("leftKey", "leftValue")), + values(ImmutableList.of("rightKey", "rightValue"))) + .withExactOutputs("rightValue"))); + } + + @Test + public void testAllInputsReferenced() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> buildProjectedJoin(p, Predicates.alwaysTrue())) + .doesNotFire(); + } + + @Test + public void testCrossJoinDoesNotFire() + { + tester().assertThat(new PruneJoinColumns()) + .on(p -> { + Symbol leftValue = p.symbol("leftValue"); + Symbol rightValue = p.symbol("rightValue"); + return p.project( + Assignments.of(), + p.join( + JoinNode.Type.INNER, + p.values(leftValue), + p.values(rightValue), + ImmutableList.of(), + ImmutableList.of(leftValue, rightValue), + Optional.empty(), + Optional.empty(), + Optional.empty())); + }) + .doesNotFire(); + } + + private static PlanNode buildProjectedJoin(PlanBuilder p, Predicate projectionFilter) + { + Symbol leftKey = p.symbol("leftKey"); + Symbol leftValue = p.symbol("leftValue"); + Symbol rightKey = p.symbol("rightKey"); + Symbol rightValue = p.symbol("rightValue"); + List outputs = ImmutableList.of(leftKey, leftValue, rightKey, rightValue); + return p.project( + Assignments.identity( + outputs.stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.join( + JoinNode.Type.INNER, + p.values(leftKey, leftValue), + p.values(rightKey, rightValue), + ImmutableList.of(new JoinNode.EquiJoinClause(leftKey, rightKey)), + outputs, + Optional.empty(), + Optional.empty(), + Optional.empty())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java new file mode 100644 index 0000000000000..3af9467e90a10 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneMarkDistinctColumns + extends BaseRuleTest +{ + @Test + public void testMarkerSymbolNotReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol key2 = p.symbol("key2"); + Symbol mark = p.symbol("mark"); + Symbol unused = p.symbol("unused"); + return p.project( + Assignments.of(key2, key.toSymbolReference()), + p.markDistinct(mark, ImmutableList.of(key), p.values(key, unused))); + }) + .matches( + strictProject( + ImmutableMap.of("key2", expression("key")), + values(ImmutableList.of("key", "unused")))); + } + + @Test + public void testSourceSymbolNotReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol mark = p.symbol("mark"); + Symbol hash = p.symbol("hash"); + Symbol unused = p.symbol("unused"); + return p.project( + Assignments.identity(mark), + p.markDistinct( + mark, + ImmutableList.of(key), + hash, + p.values(key, hash, unused))); + }) + .matches( + strictProject( + ImmutableMap.of("mark", expression("mark")), + markDistinct("mark", ImmutableList.of("key"), "hash", + strictProject( + ImmutableMap.of( + "key", expression("key"), + "hash", expression("hash")), + values(ImmutableList.of("key", "hash", "unused")))))); + } + + @Test + public void testKeySymbolNotReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol mark = p.symbol("mark"); + return p.project( + Assignments.identity(mark), + p.markDistinct(mark, ImmutableList.of(key), p.values(key))); + }) + .doesNotFire(); + } + + @Test + public void testAllOutputsReferenced() + throws Exception + { + tester().assertThat(new PruneMarkDistinctColumns()) + .on(p -> + { + Symbol key = p.symbol("key"); + Symbol mark = p.symbol("mark"); + return p.project( + Assignments.identity(key, mark), + p.markDistinct(mark, ImmutableList.of(key), p.values(key))); + }) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java new file mode 100644 index 0000000000000..4dc7a637e27ac --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinColumns.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneSemiJoinColumns + extends BaseRuleTest +{ + @Test + public void testSemiJoinNotNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("leftValue"))) + .matches( + strictProject( + ImmutableMap.of("leftValue", expression("leftValue")), + values("leftKey", "leftKeyHash", "leftValue"))); + } + + @Test + public void testAllColumnsNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> true)) + .doesNotFire(); + } + + @Test + public void testKeysNotNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> (symbol.getName().equals("leftValue") || symbol.getName().equals("match")))) + .doesNotFire(); + } + + @Test + public void testValueNotNeeded() + { + tester().assertThat(new PruneSemiJoinColumns()) + .on(p -> buildProjectedSemiJoin(p, symbol -> symbol.getName().equals("match"))) + .matches( + strictProject( + ImmutableMap.of("match", expression("match")), + semiJoin("leftKey", "rightKey", "match", + strictProject( + ImmutableMap.of( + "leftKey", expression("leftKey"), + "leftKeyHash", expression("leftKeyHash")), + values("leftKey", "leftKeyHash", "leftValue")), + values("rightKey")))); + } + + private static PlanNode buildProjectedSemiJoin(PlanBuilder p, Predicate projectionFilter) + { + Symbol match = p.symbol("match"); + Symbol leftKey = p.symbol("leftKey"); + Symbol leftKeyHash = p.symbol("leftKeyHash"); + Symbol leftValue = p.symbol("leftValue"); + Symbol rightKey = p.symbol("rightKey"); + List outputs = ImmutableList.of(match, leftKey, leftKeyHash, leftValue); + return p.project( + Assignments.identity( + outputs.stream() + .filter(projectionFilter) + .collect(toImmutableList())), + p.semiJoin( + leftKey, + rightKey, + match, + Optional.of(leftKeyHash), + Optional.empty(), + p.values(leftKey, leftKeyHash, leftValue), + p.values(rightKey))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java new file mode 100644 index 0000000000000..d64c103f1e96c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneSemiJoinFilteringSourceColumns.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.function.Predicate; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TestPruneSemiJoinFilteringSourceColumns + extends BaseRuleTest +{ + @Test + public void testNotAllColumnsReferenced() + { + tester().assertThat(new PruneSemiJoinFilteringSourceColumns()) + .on(p -> buildSemiJoin(p, symbol -> true)) + .matches( + semiJoin("leftKey", "rightKey", "match", + values("leftKey"), + strictProject( + ImmutableMap.of( + "rightKey", expression("rightKey"), + "rightKeyHash", expression("rightKeyHash")), + values("rightKey", "rightKeyHash", "rightValue")))); + } + + @Test + public void testAllColumnsNeeded() + { + tester().assertThat(new PruneSemiJoinFilteringSourceColumns()) + .on(p -> buildSemiJoin(p, symbol -> !symbol.getName().equals("rightValue"))) + .doesNotFire(); + } + + private static PlanNode buildSemiJoin(PlanBuilder p, Predicate filteringSourceSymbolFilter) + { + Symbol match = p.symbol("match"); + Symbol leftKey = p.symbol("leftKey"); + Symbol rightKey = p.symbol("rightKey"); + Symbol rightKeyHash = p.symbol("rightKeyHash"); + Symbol rightValue = p.symbol("rightValue"); + List filteringSourceSymbols = ImmutableList.of(rightKey, rightKeyHash, rightValue); + return p.semiJoin( + leftKey, + rightKey, + match, + Optional.empty(), + Optional.of(rightKeyHash), + p.values(leftKey), + p.values( + filteringSourceSymbols.stream() + .filter(filteringSourceSymbolFilter) + .collect(toImmutableList()), + ImmutableList.of())); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java new file mode 100644 index 0000000000000..18957faf46bc6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; + +public class TestPruneTableScanColumns + extends BaseRuleTest +{ + @Test + public void testNotAllOutputsReferenced() + { + tester().assertThat(new PruneTableScanColumns()) + .on(p -> + { + Symbol orderdate = p.symbol("orderdate", DATE); + Symbol totalprice = p.symbol("totalprice", DOUBLE); + return p.project( + Assignments.of(p.symbol("x"), totalprice.toSymbolReference()), + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("local", "orders", TINY_SCALE_FACTOR)), + ImmutableList.of(orderdate, totalprice), + ImmutableMap.of( + orderdate, new TpchColumnHandle(orderdate.getName(), DATE), + totalprice, new TpchColumnHandle(totalprice.getName(), DOUBLE)))); + }) + .matches( + strictProject( + ImmutableMap.of("x_", PlanMatchPattern.expression("totalprice_")), + strictTableScan("orders", ImmutableMap.of("totalprice_", "totalprice")))); + } + + @Test + public void testAllOutputsReferenced() + { + tester().assertThat(new PruneTableScanColumns()) + .on(p -> + p.project( + Assignments.of(p.symbol("y"), expression("x")), + p.tableScan( + ImmutableList.of(p.symbol("x")), + ImmutableMap.of(p.symbol("x"), new TestingColumnHandle("x"))))) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java new file mode 100644 index 0000000000000..3f78247ffe5f3 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; + +public class TestPruneValuesColumns + extends BaseRuleTest +{ + @Test + public void testNotAllOutputsReferenced() + throws Exception + { + tester().assertThat(new PruneValuesColumns()) + .on(p -> + p.project( + Assignments.of(p.symbol("y"), expression("x")), + p.values( + ImmutableList.of(p.symbol("unused"), p.symbol("x")), + ImmutableList.of( + ImmutableList.of(expression("1"), expression("2")), + ImmutableList.of(expression("3"), expression("4")))))) + .matches( + project( + ImmutableMap.of("y", PlanMatchPattern.expression("x")), + values( + ImmutableList.of("x"), + ImmutableList.of( + ImmutableList.of(expression("2")), + ImmutableList.of(expression("4")))))); + } + + @Test + public void testAllOutputsReferenced() + throws Exception + { + tester().assertThat(new PruneValuesColumns()) + .on(p -> + p.project( + Assignments.of(p.symbol("y"), expression("x")), + p.values(p.symbol("x")))) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java new file mode 100644 index 0000000000000..4d51671328e59 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; +import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; + +public class TestPushAggregationThroughOuterJoin + extends BaseRuleTest +{ + @Test + public void testPushesAggregationThroughLeftJoin() + { + tester().assertThat(new PushAggregationThroughOuterJoin()) + .on(p -> p.aggregation(ab -> ab + .source( + p.join( + JoinNode.Type.LEFT, + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))), + p.values(p.symbol("COL2")), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), + ImmutableList.of(p.symbol("COL1"), p.symbol("COL2")), + Optional.empty(), + Optional.empty(), + Optional.empty() + )) + .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .addGroupingSet(p.symbol("COL1")))) + .matches( + project(ImmutableMap.of( + "COL1", expression("COL1"), + "COALESCE", expression("coalesce(AVG, AVG_NULL)")), + join(JoinNode.Type.INNER, ImmutableList.of(), + join(JoinNode.Type.LEFT, ImmutableList.of(equiJoinClause("COL1", "COL2")), + values(ImmutableMap.of("COL1", 0)), + aggregation( + ImmutableList.of(ImmutableList.of("COL2")), + ImmutableMap.of(Optional.of("AVG"), functionCall("avg", ImmutableList.of("COL2"))), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values(ImmutableMap.of("COL2", 0)))), + aggregation( + ImmutableList.of(ImmutableList.of()), + ImmutableMap.of(Optional.of("AVG_NULL"), functionCall("avg", ImmutableList.of("null_literal"))), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values(ImmutableMap.of("null_literal", 0)))))); + } + + @Test + public void testPushesAggregationThroughRightJoin() + { + tester().assertThat(new PushAggregationThroughOuterJoin()) + .on(p -> p.aggregation(ab -> ab + .source(p.join( + JoinNode.Type.RIGHT, + p.values(p.symbol("COL2")), + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("COL2"), p.symbol("COL1"))), + ImmutableList.of(p.symbol("COL2"), p.symbol("COL1")), + Optional.empty(), + Optional.empty(), + Optional.empty())) + .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .addGroupingSet(p.symbol("COL1")))) + .matches( + project(ImmutableMap.of( + "COALESCE", expression("coalesce(AVG, AVG_NULL)"), + "COL1", expression("COL1")), + join(JoinNode.Type.INNER, ImmutableList.of(), + join(JoinNode.Type.RIGHT, ImmutableList.of(equiJoinClause("COL2", "COL1")), + aggregation( + ImmutableList.of(ImmutableList.of("COL2")), + ImmutableMap.of(Optional.of("AVG"), functionCall("avg", ImmutableList.of("COL2"))), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values(ImmutableMap.of("COL2", 0))), + values(ImmutableMap.of("COL1", 0))), + aggregation( + ImmutableList.of(ImmutableList.of()), + ImmutableMap.of( + Optional.of("AVG_NULL"), functionCall("avg", ImmutableList.of("null_literal"))), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values(ImmutableMap.of("null_literal", 0)))))); + } + + @Test + public void testDoesNotFireWhenNotDistinct() + { + tester().assertThat(new PushAggregationThroughOuterJoin()) + .on(p -> p.aggregation(ab -> ab + .source(p.join( + JoinNode.Type.LEFT, + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"), expressions("11"))), + p.values(new Symbol("COL2")), + ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), + ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), + Optional.empty(), + Optional.empty(), + Optional.empty())) + .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .addGroupingSet(new Symbol("COL1")))) + .doesNotFire(); + } + + @Test + public void testDoesNotFireWhenGroupingOnInner() + { + tester().assertThat(new PushAggregationThroughOuterJoin()) + .on(p -> p.aggregation(ab -> ab + .source(p.join(JoinNode.Type.LEFT, + p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))), + p.values(new Symbol("COL2"), new Symbol("COL3")), + ImmutableList.of(new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), + ImmutableList.of(new Symbol("COL1"), new Symbol("COL2")), + Optional.empty(), + Optional.empty(), + Optional.empty())) + .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) + .addGroupingSet(new Symbol("COL1"), new Symbol("COL3"))) + ) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java new file mode 100644 index 0000000000000..898da02bfea52 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughMarkDistinct.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static java.util.Collections.emptyList; + +public class TestPushLimitThroughMarkDistinct + extends BaseRuleTest +{ + @Test + public void test() + throws Exception + { + tester().assertThat(new PushLimitThroughMarkDistinct()) + .on(p -> + p.limit( + 1, + p.markDistinct( + p.values(), p.symbol("foo"), emptyList()))) + .matches( + node(MarkDistinctNode.class, + node(LimitNode.class, + node(ValuesNode.class)))); + } + + @Test + public void testDoesNotFire() + throws Exception + { + tester().assertThat(new PushLimitThroughMarkDistinct()) + .on(p -> + p.markDistinct( + p.limit( + 1, + p.values()), + p.symbol("foo"), + emptyList())) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java new file mode 100644 index 0000000000000..e66e450e6b6b4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -0,0 +1,154 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPushProjectionThroughExchange + extends BaseRuleTest +{ + @Test + public void testDoesNotFireNoExchange() + throws Exception + { + tester().assertThat(new PushProjectionThroughExchange()) + .on(p -> + p.project( + Assignments.of(p.symbol("x"), new LongLiteral("3")), + p.values(p.symbol("a")))) + .doesNotFire(); + } + + @Test + public void testDoesNotFireNarrowingProjection() + throws Exception + { + tester().assertThat(new PushProjectionThroughExchange()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + + return p.project( + Assignments.builder() + .put(a, a.toSymbolReference()) + .put(b, b.toSymbolReference()) + .build(), + p.exchange(e -> e + .addSource(p.values(a, b, c)) + .addInputsSet(a, b, c) + .singleDistributionPartitioningScheme(a, b, c))); + }) + .doesNotFire(); + } + + @Test + public void testSimpleMultipleInputs() + throws Exception + { + tester().assertThat(new PushProjectionThroughExchange()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol c2 = p.symbol("c2"); + Symbol x = p.symbol("x"); + return p.project( + Assignments.of( + x, new LongLiteral("3"), + c2, new SymbolReference("c") + ), + p.exchange(e -> e + .addSource( + p.values(a)) + .addSource( + p.values(b)) + .addInputsSet(a) + .addInputsSet(b) + .singleDistributionPartitioningScheme(c))); + }) + .matches( + exchange( + project( + values(ImmutableList.of("a")) + ) + .withAlias("x1", expression("3")), + project( + values(ImmutableList.of("b")) + ) + .withAlias("x2", expression("3")) + ) + // verify that data originally on symbols aliased as x1 and x2 is part of exchange output + .withAlias("x1") + .withAlias("x2")); + } + + @Test + public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() + throws Exception + { + tester().assertThat(new PushProjectionThroughExchange()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol h = p.symbol("h"); + Symbol aTimes5 = p.symbol("a_times_5"); + Symbol bTimes5 = p.symbol("b_times_5"); + Symbol hTimes5 = p.symbol("h_times_5"); + return p.project( + Assignments.builder() + .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) + .put(bTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, new SymbolReference("b"), new LongLiteral("5"))) + .put(hTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, new SymbolReference("h"), new LongLiteral("5"))) + .build(), + p.exchange(e -> e + .addSource( + p.values(a, b, h)) + .addInputsSet(a, b, h) + .fixedHashDistributionParitioningScheme( + ImmutableList.of(a, b, h), + ImmutableList.of(b), + h))); + }) + .matches( + project( + exchange( + project( + values( + ImmutableList.of("a", "b", "h") + ) + ).withNumberOfOutputColumns(5) + .withAlias("b", expression("b")) + .withAlias("h", expression("h")) + .withAlias("a_times_5", expression("a * 5")) + .withAlias("b_times_5", expression("b * 5")) + .withAlias("h_times_5", expression("h * 5")) + ) + ).withNumberOfOutputColumns(3) + .withExactOutputs("a_times_5", "b_times_5", "h_times_5") + ); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java new file mode 100644 index 0000000000000..545ea11e36ee4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPushProjectionThroughUnion + extends BaseRuleTest +{ + @Test + public void testDoesNotFire() + throws Exception + { + tester().assertThat(new PushProjectionThroughUnion()) + .on(p -> + p.project( + Assignments.of(p.symbol("x"), new LongLiteral("3")), + p.values(p.symbol("a")))) + .doesNotFire(); + } + + @Test + public void test() + throws Exception + { + tester().assertThat(new PushProjectionThroughUnion()) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol cTimes3 = p.symbol("c_times_3"); + return p.project( + Assignments.of(cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, c.toSymbolReference(), new LongLiteral("3"))), + p.union( + ImmutableList.of( + p.values(a), + p.values(b)), + ImmutableListMultimap.builder() + .put(c, a) + .put(c, b) + .build(), + ImmutableList.of(c))); + }) + .matches( + union( + project( + ImmutableMap.of("a_times_3", expression("a * 3")), + values(ImmutableList.of("a"))), + project( + ImmutableMap.of("b_times_3", expression("b * 3")), + values(ImmutableList.of("b")))) + // verify that data originally on symbols aliased as x1 and x2 is part of exchange output + .withNumberOfOutputColumns(1) + .withAlias("a_times_3") + .withAlias("b_times_3")); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveEmptyDelete.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java similarity index 66% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveEmptyDelete.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java index 7613ba5b5d398..25f9064ca17fd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveEmptyDelete.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveEmptyDelete.java @@ -11,28 +11,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative.rule.test; +package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; -import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; +import static com.facebook.presto.sql.planner.iterative.rule.test.RuleTester.CATALOG_ID; +import static com.facebook.presto.sql.planner.iterative.rule.test.RuleTester.CONNECTOR_ID; + public class TestRemoveEmptyDelete + extends BaseRuleTest { - private final RuleTester tester = new RuleTester(); - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new RemoveEmptyDelete()) + tester().assertThat(new RemoveEmptyDelete()) .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), - p.tableScan(ImmutableList.of(), ImmutableMap.of()), + p.tableScan( + new TableHandle(CONNECTOR_ID, new TpchTableHandle(CATALOG_ID, "nation", 1.0)), + ImmutableList.of(), + ImmutableMap.of()), p.symbol("a", BigintType.BIGINT)) ) .doesNotFire(); @@ -41,7 +48,7 @@ public void testDoesNotFire() @Test public void test() { - tester.assertThat(new RemoveEmptyDelete()) + tester().assertThat(new RemoveEmptyDelete()) .on(p -> p.tableDelete( new SchemaTableName("sch", "tab"), p.values(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveFullSample.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java similarity index 72% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveFullSample.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java index e9b97521fb078..c64ffd062b6cd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRemoveFullSample.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveFullSample.java @@ -11,50 +11,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative.rule.test; +package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.SampleNode.Type; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expressions; -import static io.airlift.testing.Closeables.closeAllRuntimeException; public class TestRemoveFullSample + extends BaseRuleTest { - private RuleTester tester; - - @BeforeClass - public void setUp() - { - tester = new RuleTester(); - } - - @AfterClass(alwaysRun = true) - public void tearDown() - { - closeAllRuntimeException(tester); - tester = null; - } - @Test public void testDoesNotFire() throws Exception { - tester.assertThat(new RemoveFullSample()) + tester().assertThat(new RemoveFullSample()) .on(p -> p.sample( 0.15, Type.BERNOULLI, - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -62,7 +44,7 @@ public void testDoesNotFire() public void test() throws Exception { - tester.assertThat(new RemoveFullSample()) + tester().assertThat(new RemoveFullSample()) .on(p -> p.sample( 1.0, @@ -70,7 +52,7 @@ public void test() p.filter( expression("b > 5"), p.values( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of( expressions("1", "10"), expressions("2", "11")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java new file mode 100644 index 0000000000000..47b164d7e7b78 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarLateralNodes.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.emptyList; + +public class TestRemoveUnreferencedScalarLateralNodes + extends BaseRuleTest +{ + @Test + public void testRemoveUnreferencedInput() + { + tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) + .on(p -> p.lateral( + emptyList(), + p.values(new Symbol("x")), + p.values(emptyList(), ImmutableList.of(emptyList())))) + .matches(values("x")); + } + + @Test + public void testRemoveUnreferencedSubquery() + { + tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) + .on(p -> p.lateral( + emptyList(), + p.values(emptyList(), ImmutableList.of(emptyList())), + p.values(new Symbol("x")))) + .matches(values("x")); + } + + @Test + public void testDoesNotFire() + { + tester().assertThat(new RemoveUnreferencedScalarLateralNodes()) + .on(p -> p.lateral( + emptyList(), + p.values(), + p.values())) + .doesNotFire(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestSwapAdjacentWindowsByPartitionsOrder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java similarity index 80% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestSwapAdjacentWindowsByPartitionsOrder.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 2972f12b9a65a..e26301283431d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestSwapAdjacentWindowsByPartitionsOrder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -11,12 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner.iterative.rule.test; +package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; -import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsByPartitionsOrder; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; @@ -38,14 +38,13 @@ import static com.facebook.presto.sql.tree.FrameBound.Type.CURRENT_ROW; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; -public class TestSwapAdjacentWindowsByPartitionsOrder +public class TestSwapAdjacentWindowsBySpecifications + extends BaseRuleTest { - private final RuleTester tester = new RuleTester(); - private WindowNode.Frame frame; private Signature signature; - public TestSwapAdjacentWindowsByPartitionsOrder() + public TestSwapAdjacentWindowsBySpecifications() { frame = new WindowNode.Frame(WindowFrame.Type.RANGE, UNBOUNDED_PRECEDING, Optional.empty(), CURRENT_ROW, Optional.empty()); @@ -63,8 +62,8 @@ public TestSwapAdjacentWindowsByPartitionsOrder() public void doesNotFireOnPlanWithoutWindowFunctions() throws Exception { - tester.assertThat(new SwapAdjacentWindowsByPartitionsOrder()) - .on(p -> p.values(p.symbol("a", BIGINT))) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) + .on(p -> p.values(p.symbol("a"))) .doesNotFire(); } @@ -72,14 +71,14 @@ public void doesNotFireOnPlanWithoutWindowFunctions() public void doesNotFireOnPlanWithSingleWindowNode() throws Exception { - tester.assertThat(new SwapAdjacentWindowsByPartitionsOrder()) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT)), + ImmutableList.of(p.symbol("a")), ImmutableList.of(), ImmutableMap.of()), - ImmutableMap.of(p.symbol("avg_1", BIGINT), + ImmutableMap.of(p.symbol("avg_1"), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), ImmutableList.of()), signature, frame)), - p.values(p.symbol("a", BIGINT)))) + p.values(p.symbol("a")))) .doesNotFire(); } @@ -96,26 +95,26 @@ public void subsetComesFirst() Optional windowAB = Optional.of(new Window(ImmutableList.of(new SymbolReference("a"), new SymbolReference("b")), Optional.empty(), Optional.empty())); Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new SwapAdjacentWindowsByPartitionsOrder()) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT)), + ImmutableList.of(p.symbol("a")), ImmutableList.of(), ImmutableMap.of()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), signature, frame)), p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(), ImmutableMap.of()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowAB, false, ImmutableList.of(new SymbolReference("b"))), signature, frame)), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) + p.values(p.symbol("a"), p.symbol("b"))))) .matches(window(specificationAB, - ImmutableList.of(functionCall("avg", Optional.empty(), ImmutableList.of(columnBAlias))), - window(specificationA, - ImmutableList.of(functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias))), - values(ImmutableMap.of(columnAAlias, 0, columnBAlias, 1))))); + ImmutableList.of(functionCall("avg", Optional.empty(), ImmutableList.of(columnBAlias))), + window(specificationA, + ImmutableList.of(functionCall("avg", Optional.empty(), ImmutableList.of(columnAAlias))), + values(ImmutableMap.of(columnAAlias, 0, columnBAlias, 1))))); } @Test @@ -124,21 +123,21 @@ public void dependentWindowsAreNotReordered() { Optional windowA = Optional.of(new Window(ImmutableList.of(new SymbolReference("a")), Optional.empty(), Optional.empty())); - tester.assertThat(new SwapAdjacentWindowsByPartitionsOrder()) + tester().assertThat(new SwapAdjacentWindowsBySpecifications()) .on(p -> p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT)), + ImmutableList.of(p.symbol("a")), ImmutableList.of(), ImmutableMap.of()), - ImmutableMap.of(p.symbol("avg_1", BIGINT), + ImmutableMap.of(p.symbol("avg_1"), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("avg_2"))), signature, frame)), p.window(new WindowNode.Specification( - ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), + ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of(), ImmutableMap.of()), - ImmutableMap.of(p.symbol("avg_2", BIGINT), + ImmutableMap.of(p.symbol("avg_2"), new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), signature, frame)), - p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) + p.values(p.symbol("a"), p.symbol("b"))))) .doesNotFire(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java new file mode 100644 index 0000000000000..af5e77648d972 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.type.TypeRegistry; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.assignUniqueId; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestTransformCorrelatedScalarAggregationToJoin +{ + private RuleTester tester; + private FunctionRegistry functionRegistry; + private Rule rule; + + @BeforeClass + public void setUp() + { + tester = new RuleTester(); + TypeRegistry typeRegistry = new TypeRegistry(); + functionRegistry = new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); + rule = new TransformCorrelatedScalarAggregationToJoin(functionRegistry); + } + + @Test + public void doesNotFireOnPlanWithoutApplyNode() + { + tester.assertThat(rule) + .on(p -> p.values(p.symbol("a"))) + .doesNotFire(); + } + + @Test + public void doesNotFireOnCorrelatedWithoutAggregation() + { + tester.assertThat(rule) + .on(p -> p.lateral( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.values(p.symbol("a")))) + .doesNotFire(); + } + + @Test + public void doesNotFireOnUncorrelated() + { + tester.assertThat(rule) + .on(p -> p.lateral( + ImmutableList.of(), + p.values(p.symbol("a")), + p.values(p.symbol("b")))) + .doesNotFire(); + } + + @Test + public void doesNotFireOnCorrelatedWithNonScalarAggregation() + { + tester.assertThat(rule) + .on(p -> p.lateral( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.aggregation(ab -> ab + .source(p.values(p.symbol("a"), p.symbol("b"))) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .addGroupingSet(p.symbol("b"))))) + .doesNotFire(); + } + + @Test + public void rewritesOnSubqueryWithoutProjection() + { + tester.assertThat(rule) + .on(p -> p.lateral( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.aggregation(ab -> ab + .source(p.values(p.symbol("a"), p.symbol("b"))) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .globalGrouping()))) + .matches( + project(ImmutableMap.of("sum_1", expression("sum_1"), "corr", expression("corr")), + aggregation(ImmutableMap.of("sum_1", functionCall("sum", ImmutableList.of("a"))), + join(JoinNode.Type.LEFT, + ImmutableList.of(), + assignUniqueId("unique", + values(ImmutableMap.of("corr", 0))), + project(ImmutableMap.of("non_null", expression("true")), + values(ImmutableMap.of("a", 0, "b", 1))))))); + } + + @Test + public void rewritesOnSubqueryWithProjection() + { + tester.assertThat(rule) + .on(p -> p.lateral( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.project(Assignments.of(p.symbol("expr"), p.expression("sum + 1")), + p.aggregation(ab -> ab + .source(p.values(p.symbol("a"), p.symbol("b"))) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .globalGrouping())))) + .matches( + project(ImmutableMap.of("corr", expression("corr"), "expr", expression("(\"sum_1\" + 1)")), + aggregation(ImmutableMap.of("sum_1", functionCall("sum", ImmutableList.of("a"))), + join(JoinNode.Type.LEFT, + ImmutableList.of(), + assignUniqueId("unique", + values(ImmutableMap.of("corr", 0))), + project(ImmutableMap.of("non_null", expression("true")), + values(ImmutableMap.of("a", 0, "b", 1))))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarApply.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java similarity index 69% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarApply.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java index 2621b1f0182fd..d4c34a0c913fb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarApply.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToScalarLateralJoin.java @@ -27,17 +27,16 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.apply; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.lateral; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static io.airlift.testing.Closeables.closeAllRuntimeException; -public class TestTransformExistsApplyToScalarApply +public class TestTransformExistsApplyToScalarLateralJoin { private RuleTester tester; private Rule transformExistsApplyToScalarApply; @@ -48,7 +47,7 @@ public void setUp() tester = new RuleTester(); TypeRegistry typeManager = new TypeRegistry(); FunctionRegistry registry = new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); - transformExistsApplyToScalarApply = new TransformExistsApplyToScalarApply(registry); + transformExistsApplyToScalarApply = new TransformExistsApplyToLateralNode(registry); } @AfterClass(alwaysRun = true) @@ -63,36 +62,15 @@ public void tearDown() public void testDoesNotFire() { tester.assertThat(transformExistsApplyToScalarApply) - .on(p -> p.values(p.symbol("a", BIGINT))) + .on(p -> p.values(p.symbol("a"))) .doesNotFire(); tester.assertThat(transformExistsApplyToScalarApply) .on(p -> - p.apply( - Assignments.identity(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), - ImmutableList.of(p.symbol("a", BIGINT)), - p.values(p.symbol("a", BIGINT)), - p.values(p.symbol("a", BIGINT))) - ) - .doesNotFire(); - - tester.assertThat(transformExistsApplyToScalarApply) - .on(p -> - p.apply( - Assignments.identity(p.symbol("a", BIGINT)), - ImmutableList.of(p.symbol("a", BIGINT)), - p.values(p.symbol("a", BIGINT)), - p.values(p.symbol("a", BIGINT))) - ) - .doesNotFire(); - - tester.assertThat(transformExistsApplyToScalarApply) - .on(p -> - p.apply( - Assignments.of(p.symbol("b", BOOLEAN), expression("\"a\"")), - ImmutableList.of(), - p.values(), - p.values(p.symbol("a", BIGINT))) + p.lateral( + ImmutableList.of(p.symbol("a")), + p.values(p.symbol("a")), + p.values(p.symbol("a"))) ) .doesNotFire(); } @@ -107,11 +85,10 @@ public void testRewrite() Assignments.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT \"a\")")), ImmutableList.of(), p.values(), - p.values(p.symbol("a", BIGINT))) + p.values(p.symbol("a"))) ) - .matches(apply( + .matches(lateral( ImmutableList.of(), - ImmutableMap.of("b", PlanMatchPattern.expression("\"b\"")), values(ImmutableMap.of()), project( ImmutableMap.of("b", PlanMatchPattern.expression("(\"count_expr\" > CAST(0 AS bigint))")), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java new file mode 100644 index 0000000000000..47255b4fc74b4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.tree.ExistsPredicate; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.SymbolReference; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.emptyList; + +public class TestTransformUncorrelatedInPredicateSubqueryToSemiJoin + extends BaseRuleTest +{ + @Test + public void testDoesNotFireOnNoCorrelation() + { + tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) + .on(p -> p.apply( + Assignments.of(), + emptyList(), + p.values(), + p.values())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnNonInPredicateSubquery() + { + tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) + .on(p -> p.apply( + Assignments.of(p.symbol("x"), new ExistsPredicate(new LongLiteral("1"))), + emptyList(), + p.values(), + p.values())) + .doesNotFire(); + } + + @Test + public void testFiresForInPredicate() + { + tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) + .on(p -> p.apply( + Assignments.of( + p.symbol("x"), + new InPredicate( + new SymbolReference("y"), + new SymbolReference("z"))), + emptyList(), + p.values(p.symbol("y")), + p.values(p.symbol("z")))) + .matches(node(SemiJoinNode.class, values("y"), values("z"))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java new file mode 100644 index 0000000000000..9ea4319eb045b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedLateralToJoin.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.emptyList; + +public class TestTransformUncorrelatedLateralToJoin + extends BaseRuleTest +{ + @Test + public void test() + { + tester() + .assertThat(new TransformUncorrelatedLateralToJoin()) + .on(p -> p.lateral(emptyList(), p.values(), p.values())) + .matches(join(JoinNode.Type.INNER, emptyList(), values(), values())); + } + + @Test + public void testDoesNotFire() + { + Symbol symbol = new Symbol("x"); + tester() + .assertThat(new TransformUncorrelatedLateralToJoin()) + .on(p -> p.lateral(ImmutableList.of(symbol), p.values(symbol), p.values())) + .doesNotFire(); + } +} diff --git a/presto-hive-hadoop1/src/test/java/com/facebook/presto/hive/TestHiveClient.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java similarity index 55% rename from presto-hive-hadoop1/src/test/java/com/facebook/presto/hive/TestHiveClient.java rename to presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java index b8a7bc888be69..0a0f320b25635 100644 --- a/presto-hive-hadoop1/src/test/java/com/facebook/presto/hive/TestHiveClient.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java @@ -11,20 +11,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.hive; +package com.facebook.presto.sql.planner.iterative.rule.test; +import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; -@Test(groups = "hive") -public class TestHiveClient - extends AbstractTestHiveClient +import static io.airlift.testing.Closeables.closeAllRuntimeException; + +public abstract class BaseRuleTest { - @Parameters({"hive.hadoop1.metastoreHost", "hive.hadoop1.metastorePort", "hive.hadoop1.databaseName", "hive.hadoop1.timeZone"}) + private RuleTester tester; + @BeforeClass - public void initialize(String host, int port, String databaseName, String timeZone) + public final void setUp() + { + tester = new RuleTester(); + } + + @AfterClass(alwaysRun = true) + public final void tearDown() + { + closeAllRuntimeException(tester); + tester = null; + } + + protected RuleTester tester() { - setup(host, port, databaseName, timeZone); + return tester; } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index fd6895e184192..140be79ee2708 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -28,7 +28,6 @@ import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TestingTableHandle; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; @@ -37,20 +36,27 @@ import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; +import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ListMultimap; import java.util.ArrayList; import java.util.Arrays; @@ -62,9 +68,11 @@ import java.util.function.Consumer; import java.util.stream.Stream; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; @@ -98,6 +106,11 @@ public LimitNode limit(long limit, PlanNode source) return new LimitNode(idAllocator.getNextId(), source, limit, false); } + public MarkDistinctNode markDistinct(PlanNode source, Symbol markerSymbol, List distinctSymbols) + { + return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.empty()); + } + public SampleNode sample(double sampleRatio, SampleNode.Type type, PlanNode source) { return new SampleNode(idAllocator.getNextId(), source, sampleRatio, type); @@ -108,6 +121,16 @@ public ProjectNode project(Assignments assignments, PlanNode source) return new ProjectNode(idAllocator.getNextId(), source, assignments); } + public MarkDistinctNode markDistinct(Symbol markerSymbol, List distinctSymbols, PlanNode source) + { + return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.empty()); + } + + public MarkDistinctNode markDistinct(Symbol markerSymbol, List distinctSymbols, Symbol hashSymbol, PlanNode source) + { + return new MarkDistinctNode(idAllocator.getNextId(), source, markerSymbol, distinctSymbols, Optional.of(hashSymbol)); + } + public FilterNode filter(Expression predicate, PlanNode source) { return new FilterNode(idAllocator.getNextId(), source, predicate); @@ -124,8 +147,8 @@ public class AggregationBuilder { private PlanNode source; private Map assignments = new HashMap<>(); - private List> groupingSets; - private Step step; + private List> groupingSets = new ArrayList<>(); + private Step step = Step.SINGLE; private Optional hashSymbol = Optional.empty(); private Optional groupIdSymbol = Optional.empty(); @@ -156,7 +179,19 @@ public AggregationBuilder globalGrouping() public AggregationBuilder groupingSets(List> groupingSets) { - this.groupingSets = ImmutableList.copyOf(groupingSets); + checkState(this.groupingSets.isEmpty(), "groupingSets already defined"); + this.groupingSets.addAll(groupingSets); + return this; + } + + public AggregationBuilder addGroupingSet(Symbol... symbols) + { + return addGroupingSet(ImmutableList.copyOf(symbols)); + } + + public AggregationBuilder addGroupingSet(List symbols) + { + groupingSets.add(ImmutableList.copyOf(symbols)); return this; } @@ -180,6 +215,7 @@ public AggregationBuilder groupIdSymbol(Symbol groupIdSymbol) protected AggregationNode build() { + checkState(!groupingSets.isEmpty(), "No grouping sets defined; use globalGrouping/addGroupingSet/addEmptyGroupingSet method"); return new AggregationNode( idAllocator.getNextId(), source, @@ -196,14 +232,23 @@ public ApplyNode apply(Assignments subqueryAssignments, List correlation return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation); } + public LateralJoinNode lateral(List correlation, PlanNode input, PlanNode subquery) + { + return new LateralJoinNode(idAllocator.getNextId(), input, subquery, correlation, LateralJoinNode.Type.INNER); + } + public TableScanNode tableScan(List symbols, Map assignments) + { + TableHandle tableHandle = new TableHandle(new ConnectorId("testConnector"), new TestingTableHandle()); + return tableScan(tableHandle, symbols, assignments); + } + + public TableScanNode tableScan(TableHandle tableHandle, List symbols, Map assignments) { Expression originalConstraint = null; return new TableScanNode( idAllocator.getNextId(), - new TableHandle( - new ConnectorId("testConnector"), - new TestingTableHandle()), + tableHandle, symbols, assignments, Optional.empty(), @@ -247,6 +292,26 @@ public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) .addInputsSet(child.getOutputSymbols())); } + public SemiJoinNode semiJoin( + Symbol sourceJoinSymbol, + Symbol filteringSourceJoinSymbol, + Symbol semiJoinOutput, + Optional sourceHashSymbol, + Optional filteringSourceHashSymbol, + PlanNode source, + PlanNode filteringSource) + { + return new SemiJoinNode(idAllocator.getNextId(), + source, + filteringSource, + sourceJoinSymbol, + filteringSourceJoinSymbol, + semiJoinOutput, + sourceHashSymbol, + filteringSourceHashSymbol, + Optional.empty()); + } + public ExchangeNode exchange(Consumer exchangeBuilderConsumer) { ExchangeBuilder exchangeBuilder = new ExchangeBuilder(); @@ -323,11 +388,52 @@ protected ExchangeNode build() } } + public JoinNode join(JoinNode.Type joinType, PlanNode left, PlanNode right, JoinNode.EquiJoinClause... criteria) + { + return new JoinNode(idAllocator.getNextId(), + joinType, + left, + right, + ImmutableList.copyOf(criteria), + ImmutableList.builder() + .addAll(left.getOutputSymbols()) + .addAll(right.getOutputSymbols()) + .build(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + } + + public JoinNode join( + JoinNode.Type type, + PlanNode left, + PlanNode right, + List criteria, + List outputSymbols, + Optional filter, + Optional leftHashSymbol, + Optional rightHashSymbol) + { + return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.empty()); + } + + public UnionNode union(List sources, ListMultimap outputsToInputs, List outputs) + { + return new UnionNode(idAllocator.getNextId(), (List) sources, outputsToInputs, outputs); + } + + public Symbol symbol(String name) + { + return symbol(name, BIGINT); + } + public Symbol symbol(String name, Type type) { Symbol symbol = new Symbol(name); - Type old = symbols.get(symbol); + Type old = symbols.put(symbol, type); if (old != null && !old.equals(type)) { throw new IllegalArgumentException(format("Symbol '%s' already registered with type '%s'", name, old)); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 41589f410d7f9..e078d4d636764 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -14,16 +14,23 @@ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeCost; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.security.AccessControl; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Memo; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; +import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; import java.util.Map; @@ -31,12 +38,15 @@ import java.util.function.Function; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; +import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; import static org.testng.Assert.fail; public class RuleAssert { private final Metadata metadata; + private final CostCalculator costCalculator; private Session session; private final Rule rule; @@ -44,12 +54,17 @@ public class RuleAssert private Map symbols; private PlanNode plan; + private final TransactionManager transactionManager; + private final AccessControl accessControl; - public RuleAssert(Metadata metadata, Session session, Rule rule) + public RuleAssert(Metadata metadata, CostCalculator costCalculator, Session session, Rule rule, TransactionManager transactionManager, AccessControl accessControl) { this.metadata = metadata; + this.costCalculator = costCalculator; this.session = session; this.rule = rule; + this.transactionManager = transactionManager; + this.accessControl = accessControl; } public RuleAssert setSystemProperty(String key, String value) @@ -77,37 +92,35 @@ public RuleAssert on(Function planProvider) public void doesNotFire() { - SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); - Optional result = rule.apply(plan, x -> x, idAllocator, symbolAllocator, session); + RuleApplication ruleApplication = applyRule(); - if (result.isPresent()) { + if (ruleApplication.wasRuleApplied()) { fail(String.format( "Expected %s to not fire for:\n%s", rule.getClass().getName(), - PlanPrinter.textLogicalPlan(plan, symbolAllocator.getTypes(), metadata, session, 2))); + inTransaction(session -> PlanPrinter.textLogicalPlan(plan, ruleApplication.types, metadata, costCalculator, session, 2)))); } } public void matches(PlanMatchPattern pattern) { - SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); - Optional result = rule.apply(plan, x -> x, idAllocator, symbolAllocator, session); - Map types = symbolAllocator.getTypes(); + RuleApplication ruleApplication = applyRule(); + Map types = ruleApplication.types; - if (!result.isPresent()) { + if (!ruleApplication.wasRuleApplied()) { fail(String.format( "%s did not fire for:\n%s", rule.getClass().getName(), - PlanPrinter.textLogicalPlan(plan, types, metadata, session, 2))); + formatPlan(plan, types))); } - PlanNode actual = result.get(); + PlanNode actual = ruleApplication.getResult(); if (actual == plan) { // plans are not comparable, so we can only ensure they are not the same instance fail(String.format( "%s: rule fired but return the original plan:\n%s", rule.getClass().getName(), - PlanPrinter.textLogicalPlan(plan, types, metadata, session, 2))); + formatPlan(plan, types))); } if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { @@ -120,6 +133,65 @@ public void matches(PlanMatchPattern pattern) actual.getOutputSymbols())); } - assertPlan(session, metadata, new Plan(actual, types), pattern); + inTransaction(session -> { + Map planNodeCosts = costCalculator.calculateCostForPlan(session, types, actual); + assertPlan(session, metadata, costCalculator, new Plan(actual, types, planNodeCosts), ruleApplication.lookup, pattern); + return null; + }); + } + + private RuleApplication applyRule() + { + SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); + Memo memo = new Memo(idAllocator, plan); + Lookup lookup = Lookup.from(memo::resolve); + + if (!rule.getPattern().matches(plan)) { + return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty()); + } + + Optional result = inTransaction(session -> rule.apply(memo.getNode(memo.getRootGroup()), lookup, idAllocator, symbolAllocator, session)); + + return new RuleApplication(lookup, symbolAllocator.getTypes(), result); + } + + private String formatPlan(PlanNode plan, Map types) + { + return inTransaction(session -> PlanPrinter.textLogicalPlan(plan, types, metadata, costCalculator, session, 2)); + } + + private T inTransaction(Function transactionSessionConsumer) + { + return transaction(transactionManager, accessControl) + .singleStatement() + .execute(session, session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); + return transactionSessionConsumer.apply(session); + }); + } + + private static class RuleApplication + { + private final Lookup lookup; + private final Map types; + private final Optional result; + + public RuleApplication(Lookup lookup, Map types, Optional result) + { + this.lookup = requireNonNull(lookup, "lookup is null"); + this.types = requireNonNull(types, "types is null"); + this.result = requireNonNull(result, "result is null"); + } + + private boolean wasRuleApplied() + { + return result.isPresent(); + } + + public PlanNode getResult() + { + return result.orElseThrow(() -> new IllegalStateException("Rule was not applied")); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java index b1bf754b8ab16..aa41b2ed8c152 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java @@ -14,10 +14,14 @@ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.Session; +import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.security.AccessControl; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.tpch.TpchConnectorFactory; +import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableMap; import java.io.Closeable; @@ -27,14 +31,20 @@ public class RuleTester implements Closeable { + public static final String CATALOG_ID = "local"; + public static final ConnectorId CONNECTOR_ID = new ConnectorId(CATALOG_ID); + private final Metadata metadata; + private final CostCalculator costCalculator; private final Session session; private final LocalQueryRunner queryRunner; + private final TransactionManager transactionManager; + private final AccessControl accessControl; public RuleTester() { session = testSessionBuilder() - .setCatalog("local") + .setCatalog(CATALOG_ID) .setSchema("tiny") .setSystemProperty("task_concurrency", "1") // these tests don't handle exchanges from local parallel .build(); @@ -45,11 +55,14 @@ public RuleTester() ImmutableMap.of()); this.metadata = queryRunner.getMetadata(); + this.costCalculator = queryRunner.getCostCalculator(); + this.transactionManager = queryRunner.getTransactionManager(); + this.accessControl = queryRunner.getAccessControl(); } public RuleAssert assertThat(Rule rule) { - return new RuleAssert(metadata, session, rule); + return new RuleAssert(metadata, costCalculator, session, rule, transactionManager, accessControl); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCountConstantOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCountConstantOptimizer.java index 90258250a4b7b..881dd58b02c08 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCountConstantOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCountConstantOptimizer.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; @@ -49,8 +50,10 @@ public void testCountConstantOptimizer() PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); Symbol countAggregationSymbol = new Symbol("count"); Signature countAggregationSignature = new Signature("count", FunctionKind.AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT)); - ImmutableMap aggregations = ImmutableMap.of(countAggregationSymbol, new FunctionCall(QualifiedName.of("count"), ImmutableList.of(new SymbolReference("expr")))); - ImmutableMap functions = ImmutableMap.of(countAggregationSymbol, countAggregationSignature); + ImmutableMap aggregations = ImmutableMap.of(countAggregationSymbol, new Aggregation( + new FunctionCall(QualifiedName.of("count"), ImmutableList.of(new SymbolReference("expr"))), + countAggregationSignature, + Optional.empty())); ValuesNode valuesNode = new ValuesNode(planNodeIdAllocator.getNextId(), ImmutableList.of(new Symbol("col")), ImmutableList.of(ImmutableList.of())); AggregationNode eligiblePlan = new AggregationNode( @@ -60,8 +63,6 @@ public void testCountConstantOptimizer() valuesNode, Assignments.of(new Symbol("expr"), new LongLiteral("42"))), aggregations, - functions, - ImmutableMap.of(), ImmutableList.of(ImmutableList.of()), AggregationNode.Step.INTERMEDIATE, Optional.empty(), @@ -70,6 +71,7 @@ public void testCountConstantOptimizer() assertTrue(((AggregationNode) optimizer.optimize(eligiblePlan, TEST_SESSION, ImmutableMap.of(), new SymbolAllocator(), new PlanNodeIdAllocator())) .getAggregations() .get(countAggregationSymbol) + .getCall() .getArguments() .isEmpty()); @@ -80,8 +82,6 @@ public void testCountConstantOptimizer() valuesNode, Assignments.of(new Symbol("expr"), new FunctionCall(QualifiedName.of("function"), ImmutableList.of(new Identifier("x"))))), aggregations, - functions, - ImmutableMap.of(), ImmutableList.of(ImmutableList.of()), AggregationNode.Step.INTERMEDIATE, Optional.empty(), @@ -90,6 +90,7 @@ public void testCountConstantOptimizer() assertFalse(((AggregationNode) optimizer.optimize(ineligiblePlan, TEST_SESSION, ImmutableMap.of(), new SymbolAllocator(), new PlanNodeIdAllocator())) .getAggregations() .get(countAggregationSymbol) + .getCall() .getArguments() .isEmpty()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java index f436ddc62049f..3db490c3443bd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestEliminateSorts.java @@ -92,13 +92,12 @@ public void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern new UnaliasSymbolReferences(), new AddExchanges(queryRunner.getMetadata(), new SqlParser()), new PruneUnreferencedOutputs(), - new MergeProjections(), new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())) ); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java index a6915ac4410e6..aad65cf5cf2dd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestExpressionEquivalence.java @@ -29,7 +29,7 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; -import static com.facebook.presto.sql.planner.DependencyExtractor.extractUnique; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUnique; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; import static org.testng.Assert.assertFalse; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java index e3677b9986d84..465ef6a30897b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java @@ -19,7 +19,7 @@ import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.block.SortOrder; -import com.facebook.presto.sql.planner.TestingColumnHandle; +import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.DeserializationContext; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java index e219e6dddc81a..3150ac96759d7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergeWindows.java @@ -21,7 +21,9 @@ import com.facebook.presto.sql.planner.assertions.PlanAssert; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; +import com.facebook.presto.sql.planner.iterative.rule.MergeAdjacentWindows; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; +import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FrameBound; @@ -34,6 +36,7 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.Map; import java.util.Optional; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; @@ -93,6 +96,13 @@ public class TestMergeWindows public TestMergeWindows() { + this(ImmutableMap.of()); + } + + public TestMergeWindows(Map sessionProperties) + { + super(sessionProperties); + specificationA = specification( ImmutableList.of(SUPPKEY_ALIAS), ImmutableList.of(ORDERKEY_ALIAS), @@ -166,13 +176,13 @@ public void testIdenticalWindowSpecificationsABA() assertUnitPlan(sql, anyTree( - window(specificationB, + window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - window(specificationA, - ImmutableList.of( functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS)), functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(specificationB, + ImmutableList.of( + functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQSS)))); } @@ -241,13 +251,12 @@ public void testIdenticalWindowSpecificationsDefaultFrame() assertUnitPlan(sql, anyTree( - window(specificationD, + window(specificationC, ImmutableList.of( - functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - window(specificationC, - ImmutableList.of( functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS)), functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(specificationD, + ImmutableList.of(functionCall("sum", UNSPECIFIED_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQSS)))); } @@ -394,10 +403,10 @@ public void testNotMergeDifferentPartition() assertUnitPlan(sql, anyTree( - window(specificationC, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(specificationA, ImmutableList.of( + functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(specificationC, ImmutableList.of( + functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DOQS)))); } @@ -463,11 +472,11 @@ public void testNotMergeDifferentNullOrdering() assertUnitPlan(sql, anyTree( - window(specificationC, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), - window(specificationA, ImmutableList.of( - functionCall("sum", COMMON_FRAME, ImmutableList.of(EXTENDEDPRICE_ALIAS)), - functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(specificationA, ImmutableList.of( + functionCall("sum", COMMON_FRAME, ImmutableList.of(EXTENDEDPRICE_ALIAS)), + functionCall("sum", COMMON_FRAME, ImmutableList.of(DISCOUNT_ALIAS))), + window(specificationC, ImmutableList.of( + functionCall("sum", COMMON_FRAME, ImmutableList.of(QUANTITY_ALIAS))), LINEITEM_TABLESCAN_DEOQS)))); } @@ -476,12 +485,14 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter LocalQueryRunner queryRunner = getQueryRunner(); List optimizers = ImmutableList.of( new UnaliasSymbolReferences(), - new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), - new MergeWindows(), + new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of( + new RemoveRedundantIdentityProjections(), + new SwapAdjacentWindowsBySpecifications(), + new MergeAdjacentWindows())), new PruneUnreferencedOutputs()); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java index 8c637b96e0a7d..6f5c35f3bcdf9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMixedDistinctAggregationOptimizer.java @@ -146,7 +146,7 @@ public void assertUnitPlan(String sql, PlanMatchPattern pattern) queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java index 6672c6df6a579..543bc29c34dfd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderWindows.java @@ -22,7 +22,7 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; -import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsByPartitionsOrder; +import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.WindowFrame; import com.facebook.presto.testing.LocalQueryRunner; @@ -276,11 +276,11 @@ private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern patter new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of( new RemoveRedundantIdentityProjections(), - new SwapAdjacentWindowsByPartitionsOrder())), + new SwapAdjacentWindowsBySpecifications())), new PruneUnreferencedOutputs()); queryRunner.inTransaction(transactionSession -> { Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers); - PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern); + PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getCostCalculator(), actualPlan, pattern); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java index 31e64c670bc8d..54ef7f6fff88b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSetFlatteningOptimizer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.optimizations; +import com.facebook.presto.sql.planner.LogicalPlanner; import com.facebook.presto.sql.planner.StatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; @@ -130,6 +131,6 @@ public void assertPlan(String sql, PlanMatchPattern pattern) new PruneUnreferencedOutputs(), new IterativeOptimizer(new StatsRecorder(), ImmutableSet.of(new RemoveRedundantIdentityProjections())), new SetFlatteningOptimizer()); - assertPlanWithOptimizers(sql, pattern, optimizers); + assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java index 30e3e58704ed5..6ed6702705fc9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSimplifyExpressions.java @@ -15,10 +15,10 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.tree.Expression; @@ -139,7 +139,7 @@ private static Expression simplifyExpressions(Expression expression) private static Map booleanSymbolTypeMapFor(Expression expression) { - return DependencyExtractor.extractUnique(expression).stream() + return SymbolsExtractor.extractUnique(expression).stream() .collect(Collectors.toMap(symbol -> symbol, symbol -> BOOLEAN)); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java index 08aedf7e9e731..bd805e17d91b6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnion.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TopNNode; import com.google.common.collect.Iterables; import org.testng.annotations.Test; @@ -29,6 +30,7 @@ import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -64,6 +66,55 @@ public void testSimpleUnion() assertPlanIsFullyDistributed(plan); } + @Test + public void testUnionUnderTopN() + { + Plan plan = plan( + "SELECT * FROM (" + + " SELECT regionkey FROM nation " + + " UNION ALL " + + " SELECT nationkey FROM nation" + + ") t(a) " + + "ORDER BY a LIMIT 1", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + false); + + List remotes = searchFrom(plan.getRoot()) + .where(TestUnion::isRemoteExchange) + .findAll(); + + assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange"); + assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), GATHER); + + int numberOfpartialTopN = searchFrom(plan.getRoot()) + .where(planNode -> planNode instanceof TopNNode && ((TopNNode) planNode).getStep().equals(TopNNode.Step.PARTIAL)) + .count(); + assertEquals(numberOfpartialTopN, 2, "There should be exactly two partial TopN nodes"); + assertPlanIsFullyDistributed(plan); + } + + @Test + public void testUnionOverSingleNodeAggregationAndUnion() + { + Plan plan = plan( + "SELECT count(*) FROM (" + + "SELECT 1 FROM nation GROUP BY regionkey " + + "UNION ALL (" + + " SELECT 1 FROM nation " + + " UNION ALL " + + " SELECT 1 FROM nation))", + LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, + false); + + List remotes = searchFrom(plan.getRoot()) + .where(TestUnion::isRemoteExchange) + .findAll(); + + assertEquals(remotes.size(), 2, "There should be exactly two RemoteExchanges"); + assertEquals(((ExchangeNode) remotes.get(0)).getType(), GATHER); + assertEquals(((ExchangeNode) remotes.get(1)).getType(), REPARTITION); + } + @Test public void testPartialAggregationsWithUnion() { @@ -116,8 +167,8 @@ private void assertPlanIsFullyDistributed(Plan plan) .skipOnlyWhen(TestUnion::isNotRemoteGatheringExchange) .findAll() .stream() - .noneMatch(planNode -> planNode instanceof AggregationNode || planNode instanceof JoinNode), - "There is an Aggregation or Join between output and first REMOTE GATHER ExchangeNode"); + .noneMatch(this::shouldBeDistributed), + "There is a node that should be distributed between output and first REMOTE GATHER ExchangeNode"); List gathers = searchFrom(plan.getRoot()) .where(TestUnion::isRemoteGatheringExchange) @@ -128,6 +179,21 @@ private void assertPlanIsFullyDistributed(Plan plan) assertEquals(gathers.size(), 1, "Only a single REMOTE GATHER was expected"); } + private boolean shouldBeDistributed(PlanNode planNode) + { + if (planNode instanceof JoinNode) { + return true; + } + if (planNode instanceof AggregationNode) { + // TODO: differentiate aggregation with empty grouping set + return true; + } + if (planNode instanceof TopNNode) { + return ((TopNNode) planNode).getStep() == TopNNode.Step.PARTIAL; + } + return false; + } + private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) { List fragments = searchFrom(plan.getRoot()) diff --git a/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java b/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java index db2dfa64d124f..0b4d76118b380 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/AbstractTestType.java @@ -16,6 +16,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; import com.google.common.collect.ImmutableMap; import io.airlift.slice.DynamicSliceOutput; diff --git a/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java b/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java index 4b2bcd80ecb79..4b064a8531091 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/BenchmarkDecimalOperators.java @@ -30,7 +30,7 @@ import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.util.maps.IdentityLinkedHashMap; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ImmutableList; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; @@ -608,7 +608,7 @@ private RowExpression rowExpression(String expression) Map types = sourceLayout.entrySet().stream() .collect(toMap(Map.Entry::getValue, entry -> symbolTypes.get(entry.getKey()))); - IdentityLinkedHashMap expressionTypes = getExpressionTypesFromInput(TEST_SESSION, metadata, SQL_PARSER, types, inputReferenceExpression, emptyList()); + Map, Type> expressionTypes = getExpressionTypesFromInput(TEST_SESSION, metadata, SQL_PARSER, types, inputReferenceExpression, emptyList()); return SqlToRowExpressionTranslator.translate(inputReferenceExpression, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), TEST_SESSION, true); } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java index 1fd699bc2b884..b8fa1e9153d0e 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java @@ -20,7 +20,9 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.BooleanType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.SemanticErrorCode; @@ -49,6 +51,7 @@ import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; @@ -59,6 +62,7 @@ import static com.facebook.presto.type.TypeJsonUtils.appendToBlockBuilder; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.StructuralTestUtil.arrayBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.NaN; import static java.lang.Double.POSITIVE_INFINITY; @@ -189,6 +193,8 @@ public void testJsonToArray() assertFunction("CAST(JSON '[1, null, 3]' AS ARRAY)", new ArrayType(BIGINT), asList(1L, null, 3L)); assertFunction("CAST(JSON '[1, 2.0, 3]' AS ARRAY)", new ArrayType(DOUBLE), ImmutableList.of(1.0, 2.0, 3.0)); assertFunction("CAST(JSON '[1.0, 2.5, 3.0]' AS ARRAY)", new ArrayType(DOUBLE), ImmutableList.of(1.0, 2.5, 3.0)); + assertFunction("CAST(JSON '[1, 2.5, 3]' AS ARRAY)", new ArrayType(REAL), ImmutableList.of(1.0f, 2.5f, 3.0f)); + assertFunction("CAST(JSON '[-1, null, -3]' AS ARRAY)", new ArrayType(REAL), asList(-1.0f, null, -3.0f)); assertFunction("CAST(JSON '[\"puppies\", \"kittens\"]' AS ARRAY)", new ArrayType(VARCHAR), ImmutableList.of("puppies", "kittens")); assertFunction("CAST(JSON '[true, false]' AS ARRAY)", new ArrayType(BOOLEAN), ImmutableList.of(true, false)); assertFunction("CAST(JSON '[[1], [null]]' AS ARRAY>)", new ArrayType(new ArrayType(BIGINT)), asList(asList(1L), asList((Long) null))); @@ -932,6 +938,40 @@ public void testArrayRemove() assertFunction("ARRAY_REMOVE(ARRAY [ARRAY ['foo'], ARRAY ['bar'], ARRAY ['baz']], ARRAY ['bar'])", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(ImmutableList.of("foo"), ImmutableList.of("baz"))); } + @Test + public void testRepeat() + throws Exception + { + // concrete values + assertFunction("REPEAT(1, 5)", new ArrayType(INTEGER), ImmutableList.of(1, 1, 1, 1, 1)); + assertFunction("REPEAT('varchar', 3)", new ArrayType(createVarcharType(7)), ImmutableList.of("varchar", "varchar", "varchar")); + assertFunction("REPEAT(true, 1)", new ArrayType(BOOLEAN), ImmutableList.of(true)); + assertFunction("REPEAT(0.5, 4)", new ArrayType(DOUBLE), ImmutableList.of(0.5, 0.5, 0.5, 0.5)); + assertFunction("REPEAT(array[1], 4)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1), ImmutableList.of(1), ImmutableList.of(1), ImmutableList.of(1))); + + // null values + assertFunction("REPEAT(null, 4)", new ArrayType(UNKNOWN), asList(null, null, null, null)); + assertFunction("REPEAT(cast(null as bigint), 4)", new ArrayType(BIGINT), asList(null, null, null, null)); + assertFunction("REPEAT(cast(null as double), 4)", new ArrayType(DOUBLE), asList(null, null, null, null)); + assertFunction("REPEAT(cast(null as varchar), 4)", new ArrayType(VARCHAR), asList(null, null, null, null)); + assertFunction("REPEAT(cast(null as boolean), 4)", new ArrayType(BOOLEAN), asList(null, null, null, null)); + assertFunction("REPEAT(cast(null as array(boolean)), 4)", new ArrayType(new ArrayType(BOOLEAN)), asList(null, null, null, null)); + + // 0 counts + assertFunction("REPEAT(cast(null as bigint), 0)", new ArrayType(BIGINT), ImmutableList.of()); + assertFunction("REPEAT(1, 0)", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("REPEAT('varchar', 0)", new ArrayType(createVarcharType(7)), ImmutableList.of()); + assertFunction("REPEAT(true, 0)", new ArrayType(BOOLEAN), ImmutableList.of()); + assertFunction("REPEAT(0.5, 0)", new ArrayType(DOUBLE), ImmutableList.of()); + assertFunction("REPEAT(array[1], 0)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of()); + + // illegal inputs + assertInvalidFunction("REPEAT(2, -1)", INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("REPEAT(1, 1000000)", INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("REPEAT('loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooongvarchar', 9999)", INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("REPEAT(array[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 9999)", INVALID_FUNCTION_ARGUMENT); + } + @Test public void testSequence() throws Exception @@ -958,6 +998,7 @@ public void testSequence() // failure modes assertInvalidFunction("SEQUENCE(2, -1, 1)", INVALID_FUNCTION_ARGUMENT); assertInvalidFunction("SEQUENCE(-1, -10, 1)", INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("SEQUENCE(1, 1000000)", INVALID_FUNCTION_ARGUMENT); } @Test @@ -977,6 +1018,7 @@ public void testSequenceDateTimeDayToSecond() // failure modes assertInvalidFunction("SEQUENCE(date '2016-04-12', date '2016-04-14', interval '-1' day)", INVALID_FUNCTION_ARGUMENT); assertInvalidFunction("SEQUENCE(date '2016-04-14', date '2016-04-12', interval '1' day)", INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("SEQUENCE(date '2000-04-14', date '2030-04-12', interval '1' day)", INVALID_FUNCTION_ARGUMENT); } @Test @@ -1001,6 +1043,7 @@ public void testSequenceDateTimeYearToMonth() // failure modes assertInvalidFunction("SEQUENCE(date '2016-06-12', date '2016-04-12', interval '1' month)", INVALID_FUNCTION_ARGUMENT); assertInvalidFunction("SEQUENCE(date '2016-04-12', date '2016-06-12', interval '-1' month)", INVALID_FUNCTION_ARGUMENT); + assertInvalidFunction("SEQUENCE(date '2000-04-12', date '3000-06-12', interval '1' month)", INVALID_FUNCTION_ARGUMENT); } @Override @@ -1053,9 +1096,9 @@ public void testFlatten() assertFunction("flatten(ARRAY [NULL, ARRAY [ARRAY [5, 6], ARRAY [7, 8]]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(5, 6), ImmutableList.of(7, 8))); // MAP Tests - assertFunction("flatten(ARRAY [ARRAY [MAP (ARRAY [1, 2], ARRAY [1, 2])], ARRAY [MAP (ARRAY [3, 4], ARRAY [3, 4])]])", new ArrayType(new MapType(INTEGER, INTEGER)), ImmutableList.of(ImmutableMap.of(1, 1, 2, 2), ImmutableMap.of(3, 3, 4, 4))); - assertFunction("flatten(ARRAY [ARRAY [MAP (ARRAY [1, 2], ARRAY [1, 2])], NULL])", new ArrayType(new MapType(INTEGER, INTEGER)), ImmutableList.of(ImmutableMap.of(1, 1, 2, 2))); - assertFunction("flatten(ARRAY [NULL, ARRAY [MAP (ARRAY [3, 4], ARRAY [3, 4])]])", new ArrayType(new MapType(INTEGER, INTEGER)), ImmutableList.of(ImmutableMap.of(3, 3, 4, 4))); + assertFunction("flatten(ARRAY [ARRAY [MAP (ARRAY [1, 2], ARRAY [1, 2])], ARRAY [MAP (ARRAY [3, 4], ARRAY [3, 4])]])", new ArrayType(mapType(INTEGER, INTEGER)), ImmutableList.of(ImmutableMap.of(1, 1, 2, 2), ImmutableMap.of(3, 3, 4, 4))); + assertFunction("flatten(ARRAY [ARRAY [MAP (ARRAY [1, 2], ARRAY [1, 2])], NULL])", new ArrayType(mapType(INTEGER, INTEGER)), ImmutableList.of(ImmutableMap.of(1, 1, 2, 2))); + assertFunction("flatten(ARRAY [NULL, ARRAY [MAP (ARRAY [3, 4], ARRAY [3, 4])]])", new ArrayType(mapType(INTEGER, INTEGER)), ImmutableList.of(ImmutableMap.of(3, 3, 4, 4))); } @Test @@ -1065,7 +1108,7 @@ public void testArrayHashOperator() assertArrayHashOperator("ARRAY[true, false]", BOOLEAN, ImmutableList.of(true, false)); // test with ARRAY[ MAP( ARRAY[1], ARRAY[2] ) ] - MapType mapType = new MapType(INTEGER, INTEGER); + MapType mapType = mapType(INTEGER, INTEGER); BlockBuilder mapBuilder = new InterleavedBlockBuilder(ImmutableList.of(INTEGER, INTEGER), new BlockBuilderStatus(), 2); INTEGER.writeLong(mapBuilder, 1); INTEGER.writeLong(mapBuilder, 2); diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestBigintVarcharMapType.java b/presto-main/src/test/java/com/facebook/presto/type/TestBigintVarcharMapType.java index 53a70c92c3fa8..098667d676370 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestBigintVarcharMapType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestBigintVarcharMapType.java @@ -22,16 +22,16 @@ import java.util.Map; import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestBigintVarcharMapType extends AbstractTestType { public TestBigintVarcharMapType() { - super(new TypeRegistry().getType(parseTypeSignature("map(bigint,varchar)")), Map.class, createTestBlock(new TypeRegistry().getType(parseTypeSignature("map(bigint,varchar)")))); + super(mapType(BIGINT, VARCHAR), Map.class, createTestBlock(mapType(BIGINT, VARCHAR))); } public static Block createTestBlock(Type mapType) diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestBooleanOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestBooleanOperators.java index 1cc8e6bf46db0..73c32a3c9760e 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestBooleanOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestBooleanOperators.java @@ -17,6 +17,7 @@ import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; public class TestBooleanOperators @@ -112,6 +113,14 @@ public void testBetween() assertFunction("false BETWEEN false AND false", BOOLEAN, true); } + @Test + public void testCastToReal() + throws Exception + { + assertFunction("cast(true as real)", REAL, 1.0f); + assertFunction("cast(false as real)", REAL, 0.0f); + } + @Test public void testCastToVarchar() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestCharacterStringCasts.java b/presto-main/src/test/java/com/facebook/presto/type/TestCharacterStringCasts.java index 39b206562a729..3394df93d9ec4 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestCharacterStringCasts.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestCharacterStringCasts.java @@ -61,9 +61,85 @@ public void testCharToVarcharCast() @Test public void testVarcharToCharSaturatedFloorCast() { - assertEquals(varcharToCharSaturatedFloorCast(10L, utf8Slice("1234567890")), utf8Slice("1234567890")); - assertEquals(varcharToCharSaturatedFloorCast(10L, utf8Slice("123456789")), utf8Slice("12345678")); - assertEquals(varcharToCharSaturatedFloorCast(10L, utf8Slice("12345678901")), utf8Slice("1234567890")); - assertEquals(varcharToCharSaturatedFloorCast(10L, utf8Slice("")), utf8Slice("")); + String nonBmpCharacter = new String(Character.toChars(0x1F50D)); + String nonBmpCharacterMinusOne = new String(Character.toChars(0x1F50C)); + String maxCodePoint = new String(Character.toChars(Character.MAX_CODE_POINT)); + String codePointBeforeSpace = new String(Character.toChars(' ' - 1)); + + // Truncation + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("12345")), + utf8Slice("1234")); + + // Size fits, preserved + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("1234")), + utf8Slice("1234")); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("123" + nonBmpCharacter)), + utf8Slice("123" + nonBmpCharacter)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("12" + nonBmpCharacter + "3")), + utf8Slice("12" + nonBmpCharacter + "3")); + + // Size fits, preserved except char(4) representation has trailing spaces removed + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("123 ")), + utf8Slice("123")); + + // Too short, casted back would be padded with ' ' and thus made greater (VarcharOperators.lessThan), so last character needs decrementing + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("123")), + utf8Slice("122" + maxCodePoint)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("12 ")), + utf8Slice("12" + codePointBeforeSpace + maxCodePoint)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("1 ")), + utf8Slice("1 " + codePointBeforeSpace + maxCodePoint)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice(" ")), + utf8Slice(codePointBeforeSpace + maxCodePoint + maxCodePoint + maxCodePoint)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("12" + nonBmpCharacter)), + utf8Slice("12" + nonBmpCharacterMinusOne + maxCodePoint)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("1" + nonBmpCharacter + "3")), + utf8Slice("1" + nonBmpCharacter + "2" + maxCodePoint)); + + // Too short, casted back would be padded with ' ' and thus made greater (VarcharOperators.lessThan), previous to last needs decrementing since last is \0 + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("12\0")), + utf8Slice("11" + maxCodePoint + maxCodePoint)); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("1\0")), + utf8Slice("0" + maxCodePoint + maxCodePoint + maxCodePoint)); + + // Smaller than any char(4) casted back to varchar, so the result is lowest char(4) possible + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("\0")), + utf8Slice("\0\0\0\0")); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("\0\0")), + utf8Slice("\0\0\0\0")); + assertEquals(varcharToCharSaturatedFloorCast( + 4L, + utf8Slice("")), + utf8Slice("\0\0\0\0")); } } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestIntegerVarcharMapType.java b/presto-main/src/test/java/com/facebook/presto/type/TestIntegerVarcharMapType.java index 74ace03d3a1a6..97133240b8229 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestIntegerVarcharMapType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestIntegerVarcharMapType.java @@ -22,16 +22,16 @@ import java.util.Map; import static com.facebook.presto.spi.type.IntegerType.INTEGER; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestIntegerVarcharMapType extends AbstractTestType { public TestIntegerVarcharMapType() { - super(new TypeRegistry().getType(parseTypeSignature("map(integer,varchar)")), Map.class, createTestBlock(new TypeRegistry().getType(parseTypeSignature("map(integer,varchar)")))); + super(mapType(INTEGER, VARCHAR), Map.class, createTestBlock(mapType(INTEGER, VARCHAR))); } public static Block createTestBlock(Type mapType) diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestJsonOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestJsonOperators.java index 2ce4dfa685a09..03619181ce3f4 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestJsonOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestJsonOperators.java @@ -23,6 +23,7 @@ import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DecimalType.createDecimalType; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.type.JsonType.JSON; import static java.lang.Double.NEGATIVE_INFINITY; @@ -72,8 +73,12 @@ public void testCastFromIntegrals() { assertFunction("cast(cast (null as integer) as JSON)", JSON, null); assertFunction("cast(cast (null as bigint) as JSON)", JSON, null); + assertFunction("cast(cast (null as smallint) as JSON)", JSON, null); + assertFunction("cast(cast (null as tinyint) as JSON)", JSON, null); assertFunction("cast(128 as JSON)", JSON, "128"); assertFunction("cast(BIGINT '128' as JSON)", JSON, "128"); + assertFunction("cast(SMALLINT '128' as JSON)", JSON, "128"); + assertFunction("cast(TINYINT '127' as JSON)", JSON, "127"); } @Test @@ -107,13 +112,52 @@ public void testCastToDouble() public void testCastFromDouble() throws Exception { - assertFunction("cast(cast (null as double) as JSON)", JSON, null); + assertFunction("cast(cast(null as double) as JSON)", JSON, null); assertFunction("cast(3.14 as JSON)", JSON, "3.14"); assertFunction("cast(nan() as JSON)", JSON, "\"NaN\""); assertFunction("cast(infinity() as JSON)", JSON, "\"Infinity\""); assertFunction("cast(-infinity() as JSON)", JSON, "\"-Infinity\""); } + @Test + public void testCastFromReal() + throws Exception + { + assertFunction("cast(cast(null as REAL) as JSON)", JSON, null); + assertFunction("cast(REAL '3.14' as JSON)", JSON, "3.14"); + assertFunction("cast(cast(nan() as REAL) as JSON)", JSON, "\"NaN\""); + assertFunction("cast(cast(infinity() as REAL) as JSON)", JSON, "\"Infinity\""); + assertFunction("cast(cast(-infinity() as REAL) as JSON)", JSON, "\"-Infinity\""); + } + + @Test + public void testCastToReal() + throws Exception + { + assertFunction("cast(JSON 'null' as REAL)", REAL, null); + assertFunction("cast(JSON '-128' as REAL)", REAL, -128.0f); + assertFunction("cast(JSON '128' as REAL)", REAL, 128.0f); + assertFunction("cast(JSON '12345678901234567890' as REAL)", REAL, 1.2345679e19f); + assertFunction("cast(JSON '128.9' as REAL)", REAL, 128.9f); + assertFunction("cast(JSON '1e-46' as REAL)", REAL, 0.0f); // smaller than minimum subnormal positive + assertFunction("cast(JSON '1e39' as REAL)", REAL, Float.POSITIVE_INFINITY); // overflow + assertFunction("cast(JSON '-1e39' as REAL)", REAL, Float.NEGATIVE_INFINITY); // underflow + assertFunction("cast(JSON 'true' as REAL)", REAL, 1.0f); + assertFunction("cast(JSON 'false' as REAL)", REAL, 0.0f); + assertFunction("cast(JSON '\"128\"' as REAL)", REAL, 128.0f); + assertFunction("cast(JSON '\"12345678901234567890\"' as REAL)", REAL, 1.2345679e19f); + assertFunction("cast(JSON '\"128.9\"' as REAL)", REAL, 128.9f); + assertFunction("cast(JSON '\"NaN\"' as REAL)", REAL, Float.NaN); + assertFunction("cast(JSON '\"Infinity\"' as REAL)", REAL, Float.POSITIVE_INFINITY); + assertFunction("cast(JSON '\"-Infinity\"' as REAL)", REAL, Float.NEGATIVE_INFINITY); + assertInvalidFunction("cast(JSON '\"true\"' as REAL)", INVALID_CAST_ARGUMENT); + + assertFunction("cast(JSON ' 128.9' as REAL)", REAL, 128.9f); // leading space + + assertFunction("cast(json_extract('{\"x\":1.23}', '$.x') as REAL)", REAL, 1.23f); + assertInvalidCast("cast(JSON '{ \"x\" : 123}' as REAL)"); + } + @Test public void testCastToDecimal() throws Exception diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java index da6246d1d2f80..fc1ce936a6442 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestMapOperators.java @@ -21,7 +21,8 @@ import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.SqlVarbinary; @@ -44,6 +45,7 @@ import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DecimalType.createDecimalType; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.RealType.REAL; @@ -57,7 +59,9 @@ import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.StructuralTestUtil.arrayBlockOf; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.builder; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Double.doubleToLongBits; import static java.lang.String.format; @@ -114,26 +118,26 @@ public void testStackRepresentation() public void testConstructor() throws Exception { - assertFunction("MAP(ARRAY ['1','3'], ARRAY [2,4])", new MapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 2, "3", 4)); + assertFunction("MAP(ARRAY ['1','3'], ARRAY [2,4])", mapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 2, "3", 4)); Map map = new HashMap<>(); map.put(1, 2); map.put(3, null); - assertFunction("MAP(ARRAY [1, 3], ARRAY[2, NULL])", new MapType(INTEGER, INTEGER), map); - assertFunction("MAP(ARRAY [1, 3], ARRAY [2.0, 4.0])", new MapType(INTEGER, DOUBLE), ImmutableMap.of(1, 2.0, 3, 4.0)); + assertFunction("MAP(ARRAY [1, 3], ARRAY[2, NULL])", mapType(INTEGER, INTEGER), map); + assertFunction("MAP(ARRAY [1, 3], ARRAY [2.0, 4.0])", mapType(INTEGER, DOUBLE), ImmutableMap.of(1, 2.0, 3, 4.0)); assertFunction("MAP(ARRAY[1.0, 2.0], ARRAY[ ARRAY[1, 2], ARRAY[3]])", - new MapType(DOUBLE, new ArrayType(INTEGER)), + mapType(DOUBLE, new ArrayType(INTEGER)), ImmutableMap.of(1.0, ImmutableList.of(1, 2), 2.0, ImmutableList.of(3))); assertFunction("MAP(ARRAY[1.0, 2.0], ARRAY[ ARRAY[BIGINT '1', BIGINT '2'], ARRAY[ BIGINT '3' ]])", - new MapType(DOUBLE, new ArrayType(BIGINT)), + mapType(DOUBLE, new ArrayType(BIGINT)), ImmutableMap.of(1.0, ImmutableList.of(1L, 2L), 2.0, ImmutableList.of(3L))); - assertFunction("MAP(ARRAY['puppies'], ARRAY['kittens'])", new MapType(createVarcharType(7), createVarcharType(7)), ImmutableMap.of("puppies", "kittens")); - assertFunction("MAP(ARRAY[TRUE, FALSE], ARRAY[2,4])", new MapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 2, false, 4)); - assertFunction("MAP(ARRAY['1', '100'], ARRAY[from_unixtime(1), from_unixtime(100)])", new MapType(createVarcharType(3), TIMESTAMP), ImmutableMap.of( + assertFunction("MAP(ARRAY['puppies'], ARRAY['kittens'])", mapType(createVarcharType(7), createVarcharType(7)), ImmutableMap.of("puppies", "kittens")); + assertFunction("MAP(ARRAY[TRUE, FALSE], ARRAY[2,4])", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 2, false, 4)); + assertFunction("MAP(ARRAY['1', '100'], ARRAY[from_unixtime(1), from_unixtime(100)])", mapType(createVarcharType(3), TIMESTAMP), ImmutableMap.of( "1", new SqlTimestamp(1000, TEST_SESSION.getTimeZoneKey()), "100", new SqlTimestamp(100_000, TEST_SESSION.getTimeZoneKey()))); - assertFunction("MAP(ARRAY[from_unixtime(1), from_unixtime(100)], ARRAY[1.0, 100.0])", new MapType(TIMESTAMP, DOUBLE), ImmutableMap.of( + assertFunction("MAP(ARRAY[from_unixtime(1), from_unixtime(100)], ARRAY[1.0, 100.0])", mapType(TIMESTAMP, DOUBLE), ImmutableMap.of( new SqlTimestamp(1000, TEST_SESSION.getTimeZoneKey()), 1.0, new SqlTimestamp(100_000, TEST_SESSION.getTimeZoneKey()), @@ -145,7 +149,7 @@ public void testConstructor() @Test public void testEmptyMapConstructor() { - assertFunction("MAP()", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("MAP()", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); } @Test @@ -284,26 +288,29 @@ public void testJsonToMap() throws Exception { assertFunction("CAST(JSON '{\"1\":2, \"3\": 4}' AS MAP)", - new MapType(BIGINT, BIGINT), + mapType(BIGINT, BIGINT), ImmutableMap.of(1L, 2L, 3L, 4L)); assertFunction("CAST(JSON '{\"1\":2.0, \"3\": 4.0}' AS MAP)", - new MapType(BIGINT, DOUBLE), + mapType(BIGINT, DOUBLE), ImmutableMap.of(1L, 2.0, 3L, 4.0)); + assertFunction("CAST(JSON '{\"1\":2.0, \"3\": 4.0}' AS MAP)", + mapType(BIGINT, REAL), + ImmutableMap.of(1L, 2.0f, 3L, 4.0f)); assertFunction("CAST(JSON '{\"1\":[2, 3], \"4\": [5]}' AS MAP>)", - new MapType(BIGINT, new ArrayType(BIGINT)), + mapType(BIGINT, new ArrayType(BIGINT)), ImmutableMap.of(1L, ImmutableList.of(2L, 3L), 4L, ImmutableList.of(5L))); assertFunction("CAST(JSON '{\"puppies\":\"kittens\"}' AS MAP)", - new MapType(VARCHAR, VARCHAR), + mapType(VARCHAR, VARCHAR), ImmutableMap.of("puppies", "kittens")); assertFunction("CAST(JSON '{\"true\":\"kittens\"}' AS MAP)", - new MapType(BOOLEAN, VARCHAR), + mapType(BOOLEAN, VARCHAR), ImmutableMap.of(true, "kittens")); assertFunction("CAST(JSON 'null' AS MAP)", - new MapType(BOOLEAN, VARCHAR), + mapType(BOOLEAN, VARCHAR), null); assertFunction("CAST(JSON '{\"k1\": 5, \"k2\":[1, 2, 3], \"k3\":\"e\", \"k4\":{\"a\": \"b\"}, \"k5\":null, \"k6\":\"null\", \"k7\":[null]}' AS MAP)", - new MapType(VARCHAR, JSON), - ImmutableMap.builder() + mapType(VARCHAR, JSON), + builder() .put("k1", "5") .put("k2", "[1,2,3]") .put("k3", "\"e\"") @@ -318,11 +325,11 @@ public void testJsonToMap() // The second test should never happen in real life because valid json in presto requires natural key ordering. // However, it is added to make sure that the order in the first test is not a coincidence. assertFunction("CAST(JSON '{\"k1\": {\"1klmnopq\":1, \"2klmnopq\":2, \"3klmnopq\":3, \"4klmnopq\":4, \"5klmnopq\":5, \"6klmnopq\":6, \"7klmnopq\":7}}' AS MAP)", - new MapType(VARCHAR, JSON), + mapType(VARCHAR, JSON), ImmutableMap.of("k1", "{\"1klmnopq\":1,\"2klmnopq\":2,\"3klmnopq\":3,\"4klmnopq\":4,\"5klmnopq\":5,\"6klmnopq\":6,\"7klmnopq\":7}") ); assertFunction("CAST(unchecked_to_json('{\"k1\": {\"7klmnopq\":7, \"6klmnopq\":6, \"5klmnopq\":5, \"4klmnopq\":4, \"3klmnopq\":3, \"2klmnopq\":2, \"1klmnopq\":1}}') AS MAP)", - new MapType(VARCHAR, JSON), + mapType(VARCHAR, JSON), ImmutableMap.of("k1", "{\"7klmnopq\":7,\"6klmnopq\":6,\"5klmnopq\":5,\"4klmnopq\":4,\"3klmnopq\":3,\"2klmnopq\":2,\"1klmnopq\":1}") ); @@ -375,6 +382,8 @@ public void testSubscript() assertFunction("MAP(ARRAY[from_unixtime(1), from_unixtime(100)], ARRAY[1.0, 100.0])[from_unixtime(1)]", DOUBLE, 1.0); assertInvalidFunction("MAP(ARRAY [BIGINT '1'], ARRAY [BIGINT '2'])[3]", "Key not present in map: 3"); assertInvalidFunction("MAP(ARRAY ['hi'], ARRAY [2])['missing']", "Key not present in map: missing"); + assertFunction("MAP(ARRAY[array[1,1]], ARRAY['a'])[ARRAY[1,1]]", createVarcharType(1), "a"); + assertFunction("MAP(ARRAY[('a', 'b')], ARRAY[ARRAY[100, 200]])[('a', 'b')]", new ArrayType(INTEGER), ImmutableList.of(100, 200)); } @Test @@ -524,58 +533,58 @@ public void testDistinctFrom() public void testMapConcat() throws Exception { - assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE], ARRAY [1]), MAP (CAST(ARRAY [] AS ARRAY(BOOLEAN)), CAST(ARRAY [] AS ARRAY(INTEGER))))", new MapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 1)); + assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE], ARRAY [1]), MAP (CAST(ARRAY [] AS ARRAY(BOOLEAN)), CAST(ARRAY [] AS ARRAY(INTEGER))))", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 1)); // Tests - assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE], ARRAY [1]), MAP (ARRAY [TRUE, FALSE], ARRAY [10, 20]))", new MapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 20)); - assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE, FALSE], ARRAY [1, 2]), MAP (ARRAY [TRUE, FALSE], ARRAY [10, 20]))", new MapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 20)); - assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE, FALSE], ARRAY [1, 2]), MAP (ARRAY [TRUE], ARRAY [10]))", new MapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 2)); + assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE], ARRAY [1]), MAP (ARRAY [TRUE, FALSE], ARRAY [10, 20]))", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 20)); + assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE, FALSE], ARRAY [1, 2]), MAP (ARRAY [TRUE, FALSE], ARRAY [10, 20]))", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 20)); + assertFunction("MAP_CONCAT(MAP (ARRAY [TRUE, FALSE], ARRAY [1, 2]), MAP (ARRAY [TRUE], ARRAY [10]))", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 2)); // Tests - assertFunction("MAP_CONCAT(MAP (ARRAY ['1', '2', '3'], ARRAY [1, 2, 3]), MAP (ARRAY ['1', '2', '3', '4'], ARRAY [10, 20, 30, 40]))", new MapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 10, "2", 20, "3", 30, "4", 40)); - assertFunction("MAP_CONCAT(MAP (ARRAY ['1', '2', '3', '4'], ARRAY [1, 2, 3, 4]), MAP (ARRAY ['1', '2', '3', '4'], ARRAY [10, 20, 30, 40]))", new MapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 10, "2", 20, "3", 30, "4", 40)); - assertFunction("MAP_CONCAT(MAP (ARRAY ['1', '2', '3', '4'], ARRAY [1, 2, 3, 4]), MAP (ARRAY ['1', '2', '3'], ARRAY [10, 20, 30]))", new MapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 10, "2", 20, "3", 30, "4", 4)); + assertFunction("MAP_CONCAT(MAP (ARRAY ['1', '2', '3'], ARRAY [1, 2, 3]), MAP (ARRAY ['1', '2', '3', '4'], ARRAY [10, 20, 30, 40]))", mapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 10, "2", 20, "3", 30, "4", 40)); + assertFunction("MAP_CONCAT(MAP (ARRAY ['1', '2', '3', '4'], ARRAY [1, 2, 3, 4]), MAP (ARRAY ['1', '2', '3', '4'], ARRAY [10, 20, 30, 40]))", mapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 10, "2", 20, "3", 30, "4", 40)); + assertFunction("MAP_CONCAT(MAP (ARRAY ['1', '2', '3', '4'], ARRAY [1, 2, 3, 4]), MAP (ARRAY ['1', '2', '3'], ARRAY [10, 20, 30]))", mapType(createVarcharType(1), INTEGER), ImmutableMap.of("1", 10, "2", 20, "3", 30, "4", 4)); // > Tests - assertFunction("MAP_CONCAT(MAP (ARRAY [1, 2, 3], ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0]]), MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [10.0], ARRAY [20.0], ARRAY [30.0], ARRAY [40.0]]))", new MapType(INTEGER, new ArrayType(DOUBLE)), ImmutableMap.of(1, ImmutableList.of(10.0), 2, ImmutableList.of(20.0), 3, ImmutableList.of(30.0), 4, ImmutableList.of(40.0))); - assertFunction("MAP_CONCAT(MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]]), MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [10.0], ARRAY [20.0], ARRAY [30.0], ARRAY [40.0]]))", new MapType(INTEGER, new ArrayType(DOUBLE)), ImmutableMap.of(1, ImmutableList.of(10.0), 2, ImmutableList.of(20.0), 3, ImmutableList.of(30.0), 4, ImmutableList.of(40.0))); - assertFunction("MAP_CONCAT(MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]]), MAP (ARRAY [1, 2, 3], ARRAY [ARRAY [10.0], ARRAY [20.0], ARRAY [30.0]]))", new MapType(INTEGER, new ArrayType(DOUBLE)), ImmutableMap.of(1, ImmutableList.of(10.0), 2, ImmutableList.of(20.0), 3, ImmutableList.of(30.0), 4, ImmutableList.of(4.0))); + assertFunction("MAP_CONCAT(MAP (ARRAY [1, 2, 3], ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0]]), MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [10.0], ARRAY [20.0], ARRAY [30.0], ARRAY [40.0]]))", mapType(INTEGER, new ArrayType(DOUBLE)), ImmutableMap.of(1, ImmutableList.of(10.0), 2, ImmutableList.of(20.0), 3, ImmutableList.of(30.0), 4, ImmutableList.of(40.0))); + assertFunction("MAP_CONCAT(MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]]), MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [10.0], ARRAY [20.0], ARRAY [30.0], ARRAY [40.0]]))", mapType(INTEGER, new ArrayType(DOUBLE)), ImmutableMap.of(1, ImmutableList.of(10.0), 2, ImmutableList.of(20.0), 3, ImmutableList.of(30.0), 4, ImmutableList.of(40.0))); + assertFunction("MAP_CONCAT(MAP (ARRAY [1, 2, 3, 4], ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]]), MAP (ARRAY [1, 2, 3], ARRAY [ARRAY [10.0], ARRAY [20.0], ARRAY [30.0]]))", mapType(INTEGER, new ArrayType(DOUBLE)), ImmutableMap.of(1, ImmutableList.of(10.0), 2, ImmutableList.of(20.0), 3, ImmutableList.of(30.0), 4, ImmutableList.of(4.0))); // , VARCHAR> Tests assertFunction( "MAP_CONCAT(MAP (ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0]], ARRAY ['1', '2', '3']), MAP (ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]], ARRAY ['10', '20', '30', '40']))", - new MapType(new ArrayType(DOUBLE), createVarcharType(2)), + mapType(new ArrayType(DOUBLE), createVarcharType(2)), ImmutableMap.of(ImmutableList.of(1.0), "10", ImmutableList.of(2.0), "20", ImmutableList.of(3.0), "30", ImmutableList.of(4.0), "40")); assertFunction( "MAP_CONCAT(MAP (ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0]], ARRAY ['1', '2', '3']), MAP (ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]], ARRAY ['10', '20', '30', '40']))", - new MapType(new ArrayType(DOUBLE), createVarcharType(2)), + mapType(new ArrayType(DOUBLE), createVarcharType(2)), ImmutableMap.of(ImmutableList.of(1.0), "10", ImmutableList.of(2.0), "20", ImmutableList.of(3.0), "30", ImmutableList.of(4.0), "40")); assertFunction("MAP_CONCAT(MAP (ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0], ARRAY [4.0]], ARRAY ['1', '2', '3', '4']), MAP (ARRAY [ARRAY [1.0], ARRAY [2.0], ARRAY [3.0]], ARRAY ['10', '20', '30']))", - new MapType(new ArrayType(DOUBLE), createVarcharType(2)), + mapType(new ArrayType(DOUBLE), createVarcharType(2)), ImmutableMap.of(ImmutableList.of(1.0), "10", ImmutableList.of(2.0), "20", ImmutableList.of(3.0), "30", ImmutableList.of(4.0), "4")); // Tests for concatenating multiple maps - assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), NULL, MAP(ARRAY[3], ARRAY[-3]))", new MapType(INTEGER, INTEGER), null); - assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(ARRAY[2], ARRAY[-2]), MAP(ARRAY[3], ARRAY[-3]))", new MapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 2, -2, 3, -3)); - assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(ARRAY[1], ARRAY[-2]), MAP(ARRAY[1], ARRAY[-3]))", new MapType(INTEGER, INTEGER), ImmutableMap.of(1, -3)); - assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(ARRAY[], ARRAY[]), MAP(ARRAY[3], ARRAY[-3]))", new MapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 3, -3)); - assertFunction("MAP_CONCAT(MAP(ARRAY[], ARRAY[]), MAP(ARRAY['a_string'], ARRAY['b_string']), cast(MAP(ARRAY[], ARRAY[]) AS MAP(VARCHAR, VARCHAR)))", new MapType(VARCHAR, VARCHAR), ImmutableMap.of("a_string", "b_string")); - assertFunction("MAP_CONCAT(MAP(ARRAY[], ARRAY[]), MAP(ARRAY[], ARRAY[]), MAP(ARRAY[], ARRAY[]))", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("MAP_CONCAT(MAP(), MAP(), MAP())", new MapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); - assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(), MAP(ARRAY[3], ARRAY[-3]))", new MapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 3, -3)); - assertFunction("MAP_CONCAT(MAP(ARRAY[TRUE], ARRAY[1]), MAP(ARRAY[TRUE, FALSE], ARRAY[10, 20]), MAP(ARRAY[FALSE], ARRAY[0]))", new MapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 0)); + assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), NULL, MAP(ARRAY[3], ARRAY[-3]))", mapType(INTEGER, INTEGER), null); + assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(ARRAY[2], ARRAY[-2]), MAP(ARRAY[3], ARRAY[-3]))", mapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 2, -2, 3, -3)); + assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(ARRAY[1], ARRAY[-2]), MAP(ARRAY[1], ARRAY[-3]))", mapType(INTEGER, INTEGER), ImmutableMap.of(1, -3)); + assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(ARRAY[], ARRAY[]), MAP(ARRAY[3], ARRAY[-3]))", mapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 3, -3)); + assertFunction("MAP_CONCAT(MAP(ARRAY[], ARRAY[]), MAP(ARRAY['a_string'], ARRAY['b_string']), cast(MAP(ARRAY[], ARRAY[]) AS MAP(VARCHAR, VARCHAR)))", mapType(VARCHAR, VARCHAR), ImmutableMap.of("a_string", "b_string")); + assertFunction("MAP_CONCAT(MAP(ARRAY[], ARRAY[]), MAP(ARRAY[], ARRAY[]), MAP(ARRAY[], ARRAY[]))", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("MAP_CONCAT(MAP(), MAP(), MAP())", mapType(UNKNOWN, UNKNOWN), ImmutableMap.of()); + assertFunction("MAP_CONCAT(MAP(ARRAY[1], ARRAY[-1]), MAP(), MAP(ARRAY[3], ARRAY[-3]))", mapType(INTEGER, INTEGER), ImmutableMap.of(1, -1, 3, -3)); + assertFunction("MAP_CONCAT(MAP(ARRAY[TRUE], ARRAY[1]), MAP(ARRAY[TRUE, FALSE], ARRAY[10, 20]), MAP(ARRAY[FALSE], ARRAY[0]))", mapType(BOOLEAN, INTEGER), ImmutableMap.of(true, 10, false, 0)); } @Test public void testMapToMapCast() { - assertFunction("CAST(MAP(ARRAY['1', '100'], ARRAY[true, false]) AS MAP)", new MapType(VARCHAR, BIGINT), ImmutableMap.of("1", 1L, "100", 0L)); - assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[1,2]) AS MAP)", new MapType(BIGINT, BOOLEAN), ImmutableMap.of(1L, true, 2L, true)); - assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[array[1],array[2]]) AS MAP>)", new MapType(BIGINT, new ArrayType(BOOLEAN)), ImmutableMap.of(1L, ImmutableList.of(true), 2L, ImmutableList.of(true))); - assertFunction("CAST(MAP(ARRAY[1], ARRAY[MAP(ARRAY[1.0], ARRAY[false])]) AS MAP)", new MapType(VARCHAR, new MapType(BIGINT, BIGINT)), ImmutableMap.of("1", ImmutableMap.of(1L, 0L))); - assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[DATE '2016-01-02', DATE '2016-02-03']) AS MAP(bigint, varchar))", new MapType(BIGINT, VARCHAR), ImmutableMap.of(1L, "2016-01-02", 2L, "2016-02-03")); - assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[TIMESTAMP '2016-01-02 01:02:03', TIMESTAMP '2016-02-03 03:04:05']) AS MAP(bigint, varchar))", new MapType(BIGINT, VARCHAR), ImmutableMap.of(1L, "2016-01-02 01:02:03.000", 2L, "2016-02-03 03:04:05.000")); - assertFunction("CAST(MAP(ARRAY['123', '456'], ARRAY[1.23456, 2.34567]) AS MAP(integer, real))", new MapType(INTEGER, REAL), ImmutableMap.of(123, 1.23456F, 456, 2.34567F)); - assertFunction("CAST(MAP(ARRAY['123', '456'], ARRAY[1.23456, 2.34567]) AS MAP(smallint, decimal(6,5)))", new MapType(SMALLINT, DecimalType.createDecimalType(6, 5)), ImmutableMap.of((short) 123, SqlDecimal.of("1.23456"), (short) 456, SqlDecimal.of("2.34567"))); + assertFunction("CAST(MAP(ARRAY['1', '100'], ARRAY[true, false]) AS MAP)", mapType(VARCHAR, BIGINT), ImmutableMap.of("1", 1L, "100", 0L)); + assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[1,2]) AS MAP)", mapType(BIGINT, BOOLEAN), ImmutableMap.of(1L, true, 2L, true)); + assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[array[1],array[2]]) AS MAP>)", mapType(BIGINT, new ArrayType(BOOLEAN)), ImmutableMap.of(1L, ImmutableList.of(true), 2L, ImmutableList.of(true))); + assertFunction("CAST(MAP(ARRAY[1], ARRAY[MAP(ARRAY[1.0], ARRAY[false])]) AS MAP)", mapType(VARCHAR, mapType(BIGINT, BIGINT)), ImmutableMap.of("1", ImmutableMap.of(1L, 0L))); + assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[DATE '2016-01-02', DATE '2016-02-03']) AS MAP(bigint, varchar))", mapType(BIGINT, VARCHAR), ImmutableMap.of(1L, "2016-01-02", 2L, "2016-02-03")); + assertFunction("CAST(MAP(ARRAY[1,2], ARRAY[TIMESTAMP '2016-01-02 01:02:03', TIMESTAMP '2016-02-03 03:04:05']) AS MAP(bigint, varchar))", mapType(BIGINT, VARCHAR), ImmutableMap.of(1L, "2016-01-02 01:02:03.000", 2L, "2016-02-03 03:04:05.000")); + assertFunction("CAST(MAP(ARRAY['123', '456'], ARRAY[1.23456, 2.34567]) AS MAP(integer, real))", mapType(INTEGER, REAL), ImmutableMap.of(123, 1.23456F, 456, 2.34567F)); + assertFunction("CAST(MAP(ARRAY['123', '456'], ARRAY[1.23456, 2.34567]) AS MAP(smallint, decimal(6,5)))", mapType(SMALLINT, createDecimalType(6, 5)), ImmutableMap.of((short) 123, SqlDecimal.of("1.23456"), (short) 456, SqlDecimal.of("2.34567"))); // null values Map expected = new HashMap<>(); @@ -583,7 +592,7 @@ public void testMapToMapCast() expected.put(1L, null); expected.put(2L, null); expected.put(3L, 2.0); - assertFunction("CAST(MAP(ARRAY[0, 1, 2, 3], ARRAY[1,NULL, NULL, 2]) AS MAP)", new MapType(BIGINT, DOUBLE), expected); + assertFunction("CAST(MAP(ARRAY[0, 1, 2, 3], ARRAY[1,NULL, NULL, 2]) AS MAP)", mapType(BIGINT, DOUBLE), expected); assertInvalidCast("CAST(MAP(ARRAY[1, 2], ARRAY[6, 9]) AS MAP)", "duplicate keys"); } @@ -601,7 +610,7 @@ public void testMapHashOperator() private void assertMapHashOperator(String inputString, Type keyType, Type valueType, List elements) { checkArgument(elements.size() % 2 == 0, "the size of elements should be even number"); - MapType mapType = new MapType(keyType, valueType); + MapType mapType = mapType(keyType, valueType); BlockBuilder mapArrayBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 1); BlockBuilder mapBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), elements.size()); for (int i = 0; i < elements.size(); i += 2) { diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestMapType.java b/presto-main/src/test/java/com/facebook/presto/type/TestMapType.java deleted file mode 100644 index 7a7783e10f697..0000000000000 --- a/presto-main/src/test/java/com/facebook/presto/type/TestMapType.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.type; - -import com.facebook.presto.spi.type.VarcharType; -import org.testng.annotations.Test; - -import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.VarcharType.VARCHAR; -import static org.testng.Assert.assertEquals; - -public class TestMapType -{ - @Test - public void testMapDisplayName() - { - MapType mapType = new MapType(BIGINT, VarcharType.createVarcharType(42)); - assertEquals(mapType.getDisplayName(), "map(bigint, varchar(42))"); - - mapType = new MapType(BIGINT, VARCHAR); - assertEquals(mapType.getDisplayName(), "map(bigint, varchar)"); - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java index a84cd2bd1daad..114e311084035 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestRowOperators.java @@ -18,6 +18,8 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.SemanticErrorCode; @@ -43,6 +45,7 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.type.JsonType.JSON; import static com.facebook.presto.type.TypeJsonUtils.appendToBlockBuilder; +import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Preconditions.checkArgument; import static java.lang.String.format; import static org.testng.Assert.assertEquals; @@ -134,7 +137,7 @@ public void testFieldAccessor() assertFunction("CAST(row(1, 2) AS ROW(col0 integer, col1 integer)).\"col1\"", INTEGER, 2); assertFunction("CAST(array[row(1, 2)] AS array(row(col0 integer, col1 integer)))[1].col1", INTEGER, 2); assertFunction("CAST(row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])) AS ROW(col0 boolean , col1 array(integer), col2 map(integer, double))).col1", new ArrayType(INTEGER), ImmutableList.of(1, 2)); - assertFunction("CAST(row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])) AS ROW(col0 boolean , col1 array(integer), col2 map(integer, double))).col2", new MapType(INTEGER, DOUBLE), ImmutableMap.of(1, 2.0, 3, 4.0)); + assertFunction("CAST(row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])) AS ROW(col0 boolean , col1 array(integer), col2 map(integer, double))).col2", mapType(INTEGER, DOUBLE), ImmutableMap.of(1, 2.0, 3, 4.0)); assertFunction("CAST(row(1.0, ARRAY[row(31, 4.1), row(32, 4.2)], row(3, 4.0)) AS ROW(col0 double, col1 array(row(col0 integer, col1 double)), col2 row(col0 integer, col1 double))).col1[2].col0", INTEGER, 32); // Using ROW constructor diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestSimpleRowType.java b/presto-main/src/test/java/com/facebook/presto/type/TestSimpleRowType.java index 242d99236f14d..ddd4cbde7c9ec 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestSimpleRowType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestSimpleRowType.java @@ -14,9 +14,9 @@ package com.facebook.presto.type; import com.facebook.presto.spi.block.ArrayBlockBuilder; -import com.facebook.presto.spi.block.ArrayElementBlockWriter; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.block.SingleArrayBlockWriter; import com.facebook.presto.spi.type.Type; import java.util.List; @@ -40,21 +40,21 @@ private static Block createTestBlock() { ArrayBlockBuilder blockBuilder = (ArrayBlockBuilder) TYPE.createBlockBuilder(new BlockBuilderStatus(), 3); - ArrayElementBlockWriter arrayElementBlockWriter; + SingleArrayBlockWriter singleArrayBlockWriter; - arrayElementBlockWriter = blockBuilder.beginBlockEntry(); - BIGINT.writeLong(arrayElementBlockWriter, 1); - VARCHAR.writeSlice(arrayElementBlockWriter, utf8Slice("cat")); + singleArrayBlockWriter = blockBuilder.beginBlockEntry(); + BIGINT.writeLong(singleArrayBlockWriter, 1); + VARCHAR.writeSlice(singleArrayBlockWriter, utf8Slice("cat")); blockBuilder.closeEntry(); - arrayElementBlockWriter = blockBuilder.beginBlockEntry(); - BIGINT.writeLong(arrayElementBlockWriter, 2); - VARCHAR.writeSlice(arrayElementBlockWriter, utf8Slice("cats")); + singleArrayBlockWriter = blockBuilder.beginBlockEntry(); + BIGINT.writeLong(singleArrayBlockWriter, 2); + VARCHAR.writeSlice(singleArrayBlockWriter, utf8Slice("cats")); blockBuilder.closeEntry(); - arrayElementBlockWriter = blockBuilder.beginBlockEntry(); - BIGINT.writeLong(arrayElementBlockWriter, 3); - VARCHAR.writeSlice(arrayElementBlockWriter, utf8Slice("dog")); + singleArrayBlockWriter = blockBuilder.beginBlockEntry(); + BIGINT.writeLong(singleArrayBlockWriter, 3); + VARCHAR.writeSlice(singleArrayBlockWriter, utf8Slice("dog")); blockBuilder.closeEntry(); return blockBuilder.build(); @@ -64,12 +64,12 @@ private static Block createTestBlock() protected Object getGreaterValue(Object value) { ArrayBlockBuilder blockBuilder = (ArrayBlockBuilder) TYPE.createBlockBuilder(new BlockBuilderStatus(), 1); - ArrayElementBlockWriter arrayElementBlockWriter; + SingleArrayBlockWriter singleArrayBlockWriter; Block block = (Block) value; - arrayElementBlockWriter = blockBuilder.beginBlockEntry(); - BIGINT.writeLong(arrayElementBlockWriter, block.getSingleValueBlock(0).getLong(0, 0) + 1); - VARCHAR.writeSlice(arrayElementBlockWriter, block.getSingleValueBlock(1).getSlice(0, 0, 1)); + singleArrayBlockWriter = blockBuilder.beginBlockEntry(); + BIGINT.writeLong(singleArrayBlockWriter, block.getSingleValueBlock(0).getLong(0, 0) + 1); + VARCHAR.writeSlice(singleArrayBlockWriter, block.getSingleValueBlock(1).getSlice(0, 0, 1)); blockBuilder.closeEntry(); return TYPE.getObject(blockBuilder.build(), 0); diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestSmallintVarcharMapType.java b/presto-main/src/test/java/com/facebook/presto/type/TestSmallintVarcharMapType.java index b5177ac887d83..86ef1b9f9aa6f 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestSmallintVarcharMapType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestSmallintVarcharMapType.java @@ -22,16 +22,16 @@ import java.util.Map; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestSmallintVarcharMapType extends AbstractTestType { public TestSmallintVarcharMapType() { - super(new TypeRegistry().getType(parseTypeSignature("map(smallint,varchar)")), Map.class, createTestBlock(new TypeRegistry().getType(parseTypeSignature("map(smallint,varchar)")))); + super(mapType(SMALLINT, VARCHAR), Map.class, createTestBlock(mapType(SMALLINT, VARCHAR))); } public static Block createTestBlock(Type mapType) diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestTinyintVarcharMapType.java b/presto-main/src/test/java/com/facebook/presto/type/TestTinyintVarcharMapType.java index b79a0edae2666..d0e3a44bcddb4 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestTinyintVarcharMapType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestTinyintVarcharMapType.java @@ -22,16 +22,16 @@ import java.util.Map; import static com.facebook.presto.spi.type.TinyintType.TINYINT; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; +import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestTinyintVarcharMapType extends AbstractTestType { public TestTinyintVarcharMapType() { - super(new TypeRegistry().getType(parseTypeSignature("map(tinyint,varchar)")), Map.class, createTestBlock(new TypeRegistry().getType(parseTypeSignature("map(tinyint,varchar)")))); + super(mapType(TINYINT, VARCHAR), Map.class, createTestBlock(mapType(TINYINT, VARCHAR))); } public static Block createTestBlock(Type mapType) diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java b/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java index b9bdab9a098b7..3a4db92c82216 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java @@ -59,7 +59,8 @@ public class TestTypeRegistry { - private final TypeRegistry typeRegistry = new TypeRegistry(); + private final TypeManager typeRegistry = new TypeRegistry(); + private final FunctionRegistry functionRegistry = new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); @Test public void testNonexistentType() @@ -273,8 +274,6 @@ public void testCanCoerceIsTransitive() @Test public void testCastOperatorsExistForCoercions() { - FunctionRegistry functionRegistry = new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); - Set types = getStandardPrimitiveTypes(); for (Type sourceType : types) { for (Type resultType : types) { diff --git a/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java b/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java index a38b7ef996e3c..00d0ed3bdb213 100644 --- a/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java +++ b/presto-main/src/test/java/com/facebook/presto/util/StructuralTestUtil.java @@ -13,11 +13,19 @@ */ package com.facebook.presto.util; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import java.util.Map; @@ -26,6 +34,12 @@ public final class StructuralTestUtil { + private static final TypeManager TYPE_MANAGER = new TypeRegistry(); + static { + // associate TYPE_MANAGER with a function registry + new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig()); + } + private StructuralTestUtil() {} public static Block arrayBlockOf(Type elementType, Object... values) @@ -46,4 +60,11 @@ public static Block mapBlockOf(Type keyType, Type valueType, Map value) } return blockBuilder.build(); } + + public static MapType mapType(Type keyType, Type valueType) + { + return (MapType) TYPE_MANAGER.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/util/TestTimeZoneUtils.java b/presto-main/src/test/java/com/facebook/presto/util/TestTimeZoneUtils.java index 7094b06c70e7b..2d4d9ce8dd4c9 100644 --- a/presto-main/src/test/java/com/facebook/presto/util/TestTimeZoneUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/util/TestTimeZoneUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.util; +import com.facebook.presto.server.JavaVersion; import com.facebook.presto.spi.type.TimeZoneKey; import com.google.common.collect.Sets; import org.joda.time.DateTime; @@ -45,6 +46,12 @@ public void test() if (zoneId.toLowerCase(ENGLISH).startsWith("etc/") || zoneId.toLowerCase(ENGLISH).startsWith("gmt")) { continue; } + // Known bug in Joda(https://github.com/JodaOrg/joda-time/issues/427) + // We will skip this timezone in test + if (JavaVersion.current().getMajor() == 8 && JavaVersion.current().getUpdate().orElse(0) < 121 && zoneId.equals("Asia/Rangoon")) { + continue; + } + DateTimeZone dateTimeZone = DateTimeZone.forID(zoneId); DateTimeZone indexedZone = getDateTimeZone(TimeZoneKey.getTimeZoneKey(zoneId)); diff --git a/presto-main/src/test/java/com/facebook/presto/util/maps/TestIdentityLinkedHashMap.java b/presto-main/src/test/java/com/facebook/presto/util/maps/TestIdentityLinkedHashMap.java deleted file mode 100644 index 7cbe13be63293..0000000000000 --- a/presto-main/src/test/java/com/facebook/presto/util/maps/TestIdentityLinkedHashMap.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.util.maps; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; -import org.testng.annotations.Test; - -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.stream.IntStream.range; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertNotSame; -import static org.testng.Assert.assertTrue; - -public class TestIdentityLinkedHashMap -{ - @Test - public void testUsesIdentityAsEquivalenceForKeys() - { - String key = "foo"; - String otherKey = new String(key); - assertEquals(key, otherKey); - assertNotSame(key, otherKey); - - IdentityLinkedHashMap map = new IdentityLinkedHashMap<>(); - assertTrue(map.isEmpty()); - - int value = 1; - int otherValue = 2; - - map.put(key, value); - assertEquals(map, ImmutableMap.of(key, value)); - assertTrue(map.containsKey(key)); - assertFalse(map.containsKey(otherKey)); - - map.put(otherKey, otherValue); - assertEquals(map.get(key), Integer.valueOf(value)); - assertEquals(map.get(otherKey), Integer.valueOf(otherValue)); - - assertEquals(map.size(), otherValue); - map.remove(key); - assertEquals(map.size(), value); - map.remove(otherKey); - assertTrue(map.isEmpty()); - - Set keys = ImmutableSet.of("a", "aa"); - Map expectedMap = ImmutableMap.of("a", value, "aa", otherValue); - - map.putAll(Maps.asMap(keys, String::length)); - assertEquals(expectedMap, map); - - map.clear(); - assertTrue(map.isEmpty()); - } - - @Test - public void testStableIterationOrder() - { - List keys = ImmutableList.of("All", "your", "base", "are", "belong", "to", "us"); - List expectedValues = keys.stream().map(String::length).collect(toImmutableList()); - - range(0, 10).forEach(attempt -> { - IdentityLinkedHashMap map = new IdentityLinkedHashMap<>(); - - keys.forEach(i -> map.put(i, i.length())); - - assertEquals(ImmutableList.copyOf(map.keySet()), keys); - assertEquals(ImmutableList.copyOf(map.keySet().iterator()), keys); - assertEquals(map.keySet().stream().collect(toImmutableList()), keys); - - assertEquals(ImmutableList.copyOf(map.values()), expectedValues); - assertEquals(ImmutableList.copyOf(map.values().iterator()), expectedValues); - assertEquals(map.values().stream().collect(toImmutableList()), expectedValues); - - assertEquals(ImmutableList.copyOf(map.entrySet()).stream().map(Entry::getKey).collect(toImmutableList()), keys); - assertEquals(ImmutableList.copyOf(map.entrySet()::iterator).stream().map(Entry::getKey).collect(toImmutableList()), keys); - assertEquals(map.entrySet().stream().map(Entry::getKey).collect(toImmutableList()), keys); - }); - } -} diff --git a/presto-main/src/test/resources/catalog.json b/presto-main/src/test/resources/catalog.json new file mode 100644 index 0000000000000..275d5b6cf54b7 --- /dev/null +++ b/presto-main/src/test/resources/catalog.json @@ -0,0 +1,40 @@ +{ + "catalogs": [ + { + "user": "admin", + "catalog": ".*", + "allow": true + }, + { + "catalog": "secret", + "allow": false + }, + { + "user": ".*", + "catalog": "open-to-all", + "allow": true + }, + { + "catalog": "all-allowed", + "allow": true + }, + { + "user": "alice", + "catalog": "alice-catalog", + "allow": true + }, + { + "user": "bob", + "catalog": "alice-catalog", + "allow": false + }, + { + "catalog": "allowed-absent" + }, + { + "user": "\u0194\u0194\u0194", + "catalog": "\u0200\u0200\u0200", + "allow": true + } + ] +} diff --git a/presto-memory/pom.xml b/presto-memory/pom.xml index 39f53d2eef8e5..09142ef79f4a7 100644 --- a/presto-memory/pom.xml +++ b/presto-memory/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-memory diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryDataFragment.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryDataFragment.java new file mode 100644 index 0000000000000..cadb164b39cc2 --- /dev/null +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryDataFragment.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.memory; + +import com.facebook.presto.spi.HostAddress; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.airlift.json.JsonCodec; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.json.JsonCodec.jsonCodec; +import static java.util.Objects.requireNonNull; + +public class MemoryDataFragment +{ + private static final JsonCodec MEMORY_DATA_FRAGMENT_CODEC = jsonCodec(MemoryDataFragment.class); + + private final HostAddress hostAddress; + private final long rows; + + @JsonCreator + public MemoryDataFragment( + @JsonProperty("hostAddress") HostAddress hostAddress, + @JsonProperty("rows") long rows) + { + this.hostAddress = requireNonNull(hostAddress, "hostAddress is null"); + checkArgument(rows >= 0, "Rows number can not be negative"); + this.rows = rows; + } + + @JsonProperty + public HostAddress getHostAddress() + { + return hostAddress; + } + + @JsonProperty + public long getRows() + { + return rows; + } + + public Slice toSlice() + { + return Slices.wrappedBuffer(MEMORY_DATA_FRAGMENT_CODEC.toJsonBytes(this)); + } + + public static MemoryDataFragment fromSlice(Slice fragment) + { + return MEMORY_DATA_FRAGMENT_CODEC.fromJson(fragment.getBytes()); + } + + public static MemoryDataFragment merge(MemoryDataFragment a, MemoryDataFragment b) + { + checkArgument(a.getHostAddress().equals(b.getHostAddress()), "Can not merge fragments from different hosts"); + return new MemoryDataFragment(a.getHostAddress(), a.getRows() + b.getRows()); + } +} diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java index f3b6866ec0b22..d2858e07e4fd5 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java @@ -25,8 +25,10 @@ import com.facebook.presto.spi.ConnectorTableLayoutResult; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; import com.facebook.presto.spi.connector.ConnectorMetadata; @@ -39,18 +41,20 @@ import javax.annotation.concurrent.ThreadSafe; import javax.inject.Inject; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; -import java.util.stream.Collectors; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; @@ -63,27 +67,39 @@ public class MemoryMetadata private final NodeManager nodeManager; private final String connectorId; + private final List schemas = new ArrayList<>(); private final AtomicLong nextTableId = new AtomicLong(); - private final Map tableIds = new ConcurrentHashMap<>(); - private final Map tables = new ConcurrentHashMap<>(); + private final Map tableIds = new HashMap<>(); + private final Map tables = new HashMap<>(); + private final Map> tableDataFragments = new HashMap<>(); @Inject public MemoryMetadata(NodeManager nodeManager, MemoryConnectorId connectorId) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.connectorId = requireNonNull(connectorId, "connectorId is null").toString(); + this.schemas.add(SCHEMA_NAME); } @Override public synchronized List listSchemaNames(ConnectorSession session) { - return ImmutableList.of(SCHEMA_NAME); + return ImmutableList.copyOf(schemas); } @Override - public synchronized ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) + public synchronized void createSchema(ConnectorSession session, String schemaName, Map properties) { - Long tableId = tableIds.get(tableName.getTableName()); + if (schemas.contains(schemaName)) { + throw new PrestoException(ALREADY_EXISTS, format("Schema [%s] already exists", schemaName)); + } + schemas.add(schemaName); + } + + @Override + public synchronized ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName schemaTableName) + { + Long tableId = tableIds.get(schemaTableName); if (tableId == null) { return null; } @@ -100,10 +116,8 @@ public synchronized ConnectorTableMetadata getTableMetadata(ConnectorSession ses @Override public synchronized List listTables(ConnectorSession session, String schemaNameOrNull) { - if (schemaNameOrNull != null && !schemaNameOrNull.equals(SCHEMA_NAME)) { - return ImmutableList.of(); - } return tables.values().stream() + .filter(table -> schemaNameOrNull == null || table.getSchemaName().equals(schemaNameOrNull)) .map(MemoryTableHandle::toSchemaTableName) .collect(toList()); } @@ -135,25 +149,26 @@ public synchronized Map> listTableColumns( public synchronized void dropTable(ConnectorSession session, ConnectorTableHandle tableHandle) { MemoryTableHandle handle = (MemoryTableHandle) tableHandle; - Long tableId = tableIds.remove(handle.getTableName()); + Long tableId = tableIds.remove(handle.toSchemaTableName()); if (tableId != null) { tables.remove(tableId); + tableDataFragments.remove(tableId); } } @Override public synchronized void renameTable(ConnectorSession session, ConnectorTableHandle tableHandle, SchemaTableName newTableName) { + checkTableNotExists(newTableName); MemoryTableHandle oldTableHandle = (MemoryTableHandle) tableHandle; MemoryTableHandle newTableHandle = new MemoryTableHandle( oldTableHandle.getConnectorId(), oldTableHandle.getSchemaName(), newTableName.getTableName(), oldTableHandle.getTableId(), - oldTableHandle.getColumnHandles(), - oldTableHandle.getHosts()); - tableIds.remove(oldTableHandle.getTableName()); - tableIds.put(newTableName.getTableName(), oldTableHandle.getTableId()); + oldTableHandle.getColumnHandles()); + tableIds.remove(oldTableHandle.toSchemaTableName()); + tableIds.put(newTableName, oldTableHandle.getTableId()); tables.remove(oldTableHandle.getTableId()); tables.put(oldTableHandle.getTableId(), newTableHandle); } @@ -168,24 +183,38 @@ public synchronized void createTable(ConnectorSession session, ConnectorTableMet @Override public synchronized MemoryOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout) { + checkTableNotExists(tableMetadata.getTable()); long nextId = nextTableId.getAndIncrement(); Set nodes = nodeManager.getRequiredWorkerNodes(); checkState(!nodes.isEmpty(), "No Memory nodes available"); - tableIds.put(tableMetadata.getTable().getTableName(), nextId); + tableIds.put(tableMetadata.getTable(), nextId); MemoryTableHandle table = new MemoryTableHandle( connectorId, nextId, - tableMetadata, - nodes.stream().map(Node::getHostAndPort).collect(Collectors.toList())); + tableMetadata); tables.put(table.getTableId(), table); + tableDataFragments.put(table.getTableId(), new HashMap<>()); return new MemoryOutputTableHandle(table, ImmutableSet.copyOf(tableIds.values())); } + private void checkTableNotExists(SchemaTableName tableName) + { + if (tables.values().stream() + .map(MemoryTableHandle::toSchemaTableName) + .anyMatch(tableName::equals)) { + throw new PrestoException(ALREADY_EXISTS, format("Table [%s] already exists", tableName.toString())); + } + } + @Override public synchronized Optional finishCreateTable(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments) { + requireNonNull(tableHandle, "tableHandle is null"); + MemoryOutputTableHandle memoryOutputHandle = (MemoryOutputTableHandle) tableHandle; + + updateRowsOnHosts(memoryOutputHandle.getTable(), fragments); return Optional.empty(); } @@ -199,9 +228,28 @@ public synchronized MemoryInsertTableHandle beginInsert(ConnectorSession session @Override public synchronized Optional finishInsert(ConnectorSession session, ConnectorInsertTableHandle insertHandle, Collection fragments) { + requireNonNull(insertHandle, "insertHandle is null"); + MemoryInsertTableHandle memoryInsertHandle = (MemoryInsertTableHandle) insertHandle; + + updateRowsOnHosts(memoryInsertHandle.getTable(), fragments); return Optional.empty(); } + private void updateRowsOnHosts(MemoryTableHandle table, Collection fragments) + { + checkState( + tableDataFragments.containsKey(table.getTableId()), + "Uninitialized table [%s.%s]", + table.getSchemaName(), + table.getTableName()); + Map dataFragments = tableDataFragments.get(table.getTableId()); + + for (Slice fragment : fragments) { + MemoryDataFragment memoryDataFragment = MemoryDataFragment.fromSlice(fragment); + dataFragments.merge(memoryDataFragment.getHostAddress(), memoryDataFragment, MemoryDataFragment::merge); + } + } + @Override public synchronized List getTableLayouts( ConnectorSession session, @@ -211,8 +259,17 @@ public synchronized List getTableLayouts( { requireNonNull(handle, "handle is null"); checkArgument(handle instanceof MemoryTableHandle); + MemoryTableHandle memoryTableHandle = (MemoryTableHandle) handle; + checkState( + tableDataFragments.containsKey(memoryTableHandle.getTableId()), + "Inconsistent state for the table [%s.%s]", + memoryTableHandle.getSchemaName(), + memoryTableHandle.getTableName()); + + List expectedFragments = ImmutableList.copyOf( + tableDataFragments.get(memoryTableHandle.getTableId()).values()); - MemoryTableLayoutHandle layoutHandle = new MemoryTableLayoutHandle((MemoryTableHandle) handle); + MemoryTableLayoutHandle layoutHandle = new MemoryTableLayoutHandle(memoryTableHandle, expectedFragments); return ImmutableList.of(new ConnectorTableLayoutResult(getTableLayout(session, layoutHandle), constraint.getSummary())); } diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java index 0066da1c0afe7..3be123c20dc01 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java @@ -17,9 +17,12 @@ import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; @@ -36,11 +39,19 @@ public class MemoryPageSinkProvider implements ConnectorPageSinkProvider { private final MemoryPagesStore pagesStore; + private final HostAddress currentHostAddress; @Inject - public MemoryPageSinkProvider(MemoryPagesStore pagesStore) + public MemoryPageSinkProvider(MemoryPagesStore pagesStore, NodeManager nodeManager) + { + this(pagesStore, requireNonNull(nodeManager, "nodeManager is null").getCurrentNode().getHostAndPort()); + } + + @VisibleForTesting + public MemoryPageSinkProvider(MemoryPagesStore pagesStore, HostAddress currentHostAddress) { this.pagesStore = requireNonNull(pagesStore, "pagesStore is null"); + this.currentHostAddress = requireNonNull(currentHostAddress, "currentHostAddress is null"); } @Override @@ -53,7 +64,7 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa pagesStore.cleanUp(memoryOutputTableHandle.getActiveTableIds()); pagesStore.initialize(tableId); - return new MemoryPageSink(pagesStore, tableId); + return new MemoryPageSink(pagesStore, currentHostAddress, tableId); } @Override @@ -65,18 +76,22 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa checkState(memoryInsertTableHandle.getActiveTableIds().contains(tableId)); pagesStore.cleanUp(memoryInsertTableHandle.getActiveTableIds()); - return new MemoryPageSink(pagesStore, tableId); + pagesStore.initialize(tableId); + return new MemoryPageSink(pagesStore, currentHostAddress, tableId); } private static class MemoryPageSink implements ConnectorPageSink { private final MemoryPagesStore pagesStore; + private final HostAddress currentHostAddress; private final long tableId; + private long addedRows; - public MemoryPageSink(MemoryPagesStore pagesStore, long tableId) + public MemoryPageSink(MemoryPagesStore pagesStore, HostAddress currentHostAddress, long tableId) { this.pagesStore = requireNonNull(pagesStore, "pagesStore is null"); + this.currentHostAddress = requireNonNull(currentHostAddress, "currentHostAddress is null"); this.tableId = tableId; } @@ -84,13 +99,14 @@ public MemoryPageSink(MemoryPagesStore pagesStore, long tableId) public CompletableFuture appendPage(Page page) { pagesStore.add(tableId, page); + addedRows += page.getPositionCount(); return NOT_BLOCKED; } @Override public CompletableFuture> finish() { - return completedFuture(ImmutableList.of()); + return completedFuture(ImmutableList.of(new MemoryDataFragment(currentHostAddress, addedRows).toSlice())); } @Override diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java index 1f3e28f617db4..b0c705b895407 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java @@ -51,11 +51,17 @@ public ConnectorPageSource createPageSource( long tableId = memorySplit.getTableHandle().getTableId(); int partNumber = memorySplit.getPartNumber(); int totalParts = memorySplit.getTotalPartsPerWorker(); + long expectedRows = memorySplit.getExpectedRows(); List columnIndexes = columns.stream() .map(MemoryColumnHandle.class::cast) .map(MemoryColumnHandle::getColumnIndex).collect(toList()); - List pages = pagesStore.getPages(tableId, partNumber, totalParts, columnIndexes); + List pages = pagesStore.getPages( + tableId, + partNumber, + totalParts, + columnIndexes, + expectedRows); return new FixedPageSource(pages); } diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java index fc5b07c006a90..e5d871bf08885 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java @@ -42,18 +42,18 @@ public class MemoryPagesStore @GuardedBy("this") private long currentBytes = 0; + private final Map tables = new HashMap<>(); + @Inject public MemoryPagesStore(MemoryConfig config) { this.maxBytes = config.getMaxDataPerNode().toBytes(); } - private final Map> pages = new HashMap<>(); - public synchronized void initialize(long tableId) { - if (!pages.containsKey(tableId)) { - pages.put(tableId, new ArrayList<>()); + if (!tables.containsKey(tableId)) { + tables.put(tableId, new TableData()); } } @@ -69,21 +69,30 @@ public synchronized void add(Long tableId, Page page) } currentBytes = newSize; - List tablePages = pages.get(tableId); - tablePages.add(page); + TableData tableData = tables.get(tableId); + tableData.add(page); } - public synchronized List getPages(Long tableId, int partNumber, int totalParts, List columnIndexes) + public synchronized List getPages( + Long tableId, + int partNumber, + int totalParts, + List columnIndexes, + long expectedRows) { if (!contains(tableId)) { throw new PrestoException(MISSING_DATA, "Failed to find table on a worker."); } + TableData tableData = tables.get(tableId); + if (tableData.getRows() < expectedRows) { + throw new PrestoException(MISSING_DATA, + format("Expected to find [%s] rows on a worker, but found [%s].", expectedRows, tableData.getRows())); + } - List tablePages = pages.get(tableId); ImmutableList.Builder partitionedPages = ImmutableList.builder(); - for (int i = partNumber; i < tablePages.size(); i += totalParts) { - partitionedPages.add(getColumns(tablePages.get(i), columnIndexes)); + for (int i = partNumber; i < tableData.getPages().size(); i += totalParts) { + partitionedPages.add(getColumns(tableData.getPages().get(i), columnIndexes)); } return partitionedPages.build(); @@ -91,7 +100,7 @@ public synchronized List getPages(Long tableId, int partNumber, int totalP public synchronized boolean contains(Long tableId) { - return pages.containsKey(tableId); + return tables.containsKey(tableId); } public synchronized void cleanUp(Set activeTableIds) @@ -110,14 +119,14 @@ public synchronized void cleanUp(Set activeTableIds) } long latestTableId = Collections.max(activeTableIds); - for (Iterator>> tablePages = pages.entrySet().iterator(); tablePages.hasNext(); ) { - Map.Entry> tablePagesEntry = tablePages.next(); + for (Iterator> tableDataIterator = tables.entrySet().iterator(); tableDataIterator.hasNext(); ) { + Map.Entry tablePagesEntry = tableDataIterator.next(); Long tableId = tablePagesEntry.getKey(); if (tableId < latestTableId && !activeTableIds.contains(tableId)) { - for (Page removedPage : tablePagesEntry.getValue()) { + for (Page removedPage : tablePagesEntry.getValue().getPages()) { currentBytes -= removedPage.getRetainedSizeInBytes(); } - tablePages.remove(); + tableDataIterator.remove(); } } } @@ -133,4 +142,26 @@ private static Page getColumns(Page page, List columnIndexes) return new Page(page.getPositionCount(), outputBlocks); } + + private static final class TableData + { + private List pages = new ArrayList<>(); + private long rows; + + public void add(Page page) + { + pages.add(page); + rows += page.getPositionCount(); + } + + private List getPages() + { + return pages; + } + + private long getRows() + { + return rows; + } + } } diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplit.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplit.java index a0694b924ffea..45fa1cecef73e 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplit.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplit.java @@ -32,14 +32,16 @@ public class MemorySplit private final MemoryTableHandle tableHandle; private final int totalPartsPerWorker; // how many concurrent reads there will be from one worker private final int partNumber; // part of the pages on one worker that this splits is responsible - private final List addresses; + private final HostAddress address; + private final long expectedRows; @JsonCreator public MemorySplit( @JsonProperty("tableHandle") MemoryTableHandle tableHandle, @JsonProperty("partNumber") int partNumber, @JsonProperty("totalPartsPerWorker") int totalPartsPerWorker, - @JsonProperty("addresses") List addresses) + @JsonProperty("address") HostAddress address, + @JsonProperty("expectedRows") long expectedRows) { checkState(partNumber >= 0, "partNumber must be >= 0"); checkState(totalPartsPerWorker >= 1, "totalPartsPerWorker must be >= 1"); @@ -48,7 +50,8 @@ public MemorySplit( this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); this.partNumber = partNumber; this.totalPartsPerWorker = totalPartsPerWorker; - this.addresses = ImmutableList.copyOf(requireNonNull(addresses, "addresses is null")); + this.address = requireNonNull(address, "address is null"); + this.expectedRows = expectedRows; } @JsonProperty @@ -82,10 +85,21 @@ public boolean isRemotelyAccessible() } @JsonProperty + public HostAddress getAddress() + { + return address; + } + @Override public List getAddresses() { - return addresses; + return ImmutableList.of(address); + } + + @JsonProperty + public long getExpectedRows() + { + return expectedRows; } @Override diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java index 47d86692a928f..4e8b7dcedf422 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java @@ -18,7 +18,6 @@ import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.FixedSplitSource; -import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; @@ -43,17 +42,18 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand { MemoryTableLayoutHandle layout = (MemoryTableLayoutHandle) layoutHandle; - List hosts = layout.getTable().getHosts(); + List dataFragments = layout.getDataFragments(); ImmutableList.Builder splits = ImmutableList.builder(); - for (HostAddress host : hosts) { + for (MemoryDataFragment dataFragment : dataFragments) { for (int i = 0; i < splitsPerNode; i++) { splits.add( new MemorySplit( layout.getTable(), i, splitsPerNode, - ImmutableList.of(host))); + dataFragment.getHostAddress(), + dataFragment.getRows())); } } return new FixedSplitSource(splits.build()); diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableHandle.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableHandle.java index 8765e3985ace8..cd18ac9e8353d 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableHandle.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableHandle.java @@ -15,7 +15,6 @@ import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableMetadata; -import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.SchemaTableName; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -35,20 +34,17 @@ public final class MemoryTableHandle private final String tableName; private final Long tableId; private final List columnHandles; - private final List hosts; public MemoryTableHandle( String connectorId, Long tableId, - ConnectorTableMetadata tableMetadata, - List hosts) + ConnectorTableMetadata tableMetadata) { this(connectorId, tableMetadata.getTable().getSchemaName(), tableMetadata.getTable().getTableName(), tableId, - MemoryColumnHandle.extractColumnHandles(tableMetadata.getColumns()), - hosts); + MemoryColumnHandle.extractColumnHandles(tableMetadata.getColumns())); } @JsonCreator @@ -57,15 +53,13 @@ public MemoryTableHandle( @JsonProperty("schemaName") String schemaName, @JsonProperty("tableName") String tableName, @JsonProperty("tableId") Long tableId, - @JsonProperty("columnHandles") List columnHandles, - @JsonProperty("hosts") List hosts) + @JsonProperty("columnHandles") List columnHandles) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.tableName = requireNonNull(tableName, "tableName is null"); this.tableId = requireNonNull(tableId, "tableId is null"); this.columnHandles = requireNonNull(columnHandles, "columnHandles is null"); - this.hosts = requireNonNull(hosts, "hosts is null"); } @JsonProperty @@ -98,12 +92,6 @@ public List getColumnHandles() return columnHandles; } - @JsonProperty - public List getHosts() - { - return hosts; - } - public ConnectorTableMetadata toTableMetadata() { return new ConnectorTableMetadata( diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableLayoutHandle.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableLayoutHandle.java index bc823844619fe..9ba0768c7777e 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableLayoutHandle.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryTableLayoutHandle.java @@ -17,17 +17,23 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + import static java.util.Objects.requireNonNull; public class MemoryTableLayoutHandle implements ConnectorTableLayoutHandle { private final MemoryTableHandle table; + private final List dataFragments; @JsonCreator - public MemoryTableLayoutHandle(@JsonProperty("table") MemoryTableHandle table) + public MemoryTableLayoutHandle( + @JsonProperty("table") MemoryTableHandle table, + @JsonProperty("dataFragments") List dataFragments) { this.table = requireNonNull(table, "table is null"); + this.dataFragments = requireNonNull(dataFragments, "dataFragments is null"); } @JsonProperty @@ -36,6 +42,12 @@ public MemoryTableHandle getTable() return table; } + @JsonProperty + public List getDataFragments() + { + return dataFragments; + } + public String getConnectorId() { return table.getConnectorId(); diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/MemoryQueryRunner.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/MemoryQueryRunner.java index b39cc47161e47..fc23d80e21174 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/MemoryQueryRunner.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/MemoryQueryRunner.java @@ -30,6 +30,8 @@ public final class MemoryQueryRunner { + public static final String CATALOG = "memory"; + private MemoryQueryRunner() {} public static DistributedQueryRunner createQueryRunner() @@ -42,7 +44,7 @@ public static DistributedQueryRunner createQueryRunner(Map extra throws Exception { Session session = testSessionBuilder() - .setCatalog("memory") + .setCatalog(CATALOG) .setSchema("default") .build(); @@ -50,7 +52,7 @@ public static DistributedQueryRunner createQueryRunner(Map extra try { queryRunner.installPlugin(new MemoryPlugin()); - queryRunner.createCatalog("memory", "memory", ImmutableMap.of()); + queryRunner.createCatalog(CATALOG, "memory", ImmutableMap.of()); queryRunner.installPlugin(new TpchPlugin()); queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of()); diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java index 15a4a1897fef4..3439780b4fbba 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java @@ -14,7 +14,13 @@ package com.facebook.presto.plugin.memory; import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.testing.TestingNodeManager; import com.google.common.collect.ImmutableList; @@ -25,10 +31,12 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; @Test(singleThreaded = true) public class TestMemoryMetadata @@ -44,7 +52,7 @@ public void setUp() @Test public void tableIsCreatedAfterCommits() { - assertThatNoTableIsCreated(); + assertNoTables(); SchemaTableName schemaTableName = new SchemaTableName("default", "temp_table"); @@ -56,14 +64,45 @@ public void tableIsCreatedAfterCommits() metadata.finishCreateTable(SESSION, table, ImmutableList.of()); List tables = metadata.listTables(SESSION, null); - assertTrue(tables.size() == 1, "Expected only one table."); + assertTrue(tables.size() == 1, "Expected only one table"); assertTrue(tables.get(0).getTableName().equals("temp_table"), "Expected table with name 'temp_table'"); } + @Test + public void tableAlreadyExists() + { + assertNoTables(); + + SchemaTableName test1Table = new SchemaTableName("default", "test1"); + SchemaTableName test2Table = new SchemaTableName("default", "test2"); + metadata.createTable(SESSION, new ConnectorTableMetadata(test1Table, ImmutableList.of())); + + try { + metadata.createTable(SESSION, new ConnectorTableMetadata(test1Table, ImmutableList.of())); + fail("Should fail because table already exists"); + } + catch (PrestoException ex) { + assertEquals(ex.getErrorCode(), ALREADY_EXISTS.toErrorCode()); + assertEquals(ex.getMessage(), "Table [default.test1] already exists"); + } + + ConnectorTableHandle test1TableHandle = metadata.getTableHandle(SESSION, test1Table); + metadata.createTable(SESSION, new ConnectorTableMetadata(test2Table, ImmutableList.of())); + + try { + metadata.renameTable(SESSION, test1TableHandle, test2Table); + fail("Should fail because table already exists"); + } + catch (PrestoException ex) { + assertEquals(ex.getErrorCode(), ALREADY_EXISTS.toErrorCode()); + assertEquals(ex.getMessage(), "Table [default.test2] already exists"); + } + } + @Test public void testActiveTableIds() { - assertThatNoTableIsCreated(); + assertNoTables(); SchemaTableName firstTableName = new SchemaTableName("default", "first_table"); metadata.createTable(SESSION, new ConnectorTableMetadata(firstTableName, ImmutableList.of(), ImmutableMap.of())); @@ -84,7 +123,54 @@ public void testActiveTableIds() assertTrue(metadata.beginInsert(SESSION, secondTableHandle).getActiveTableIds().contains(secondTableId)); } - private void assertThatNoTableIsCreated() + @Test + public void testReadTableBeforeCreationCompleted() + { + assertNoTables(); + + SchemaTableName tableName = new SchemaTableName("default", "temp_table"); + + ConnectorOutputTableHandle table = metadata.beginCreateTable( + SESSION, + new ConnectorTableMetadata(tableName, ImmutableList.of(), ImmutableMap.of()), + Optional.empty()); + + List tableNames = metadata.listTables(SESSION, null); + assertTrue(tableNames.size() == 1, "Expected exactly one table"); + + ConnectorTableHandle tableHandle = metadata.getTableHandle(SESSION, tableName); + List tableLayouts = metadata.getTableLayouts(SESSION, tableHandle, Constraint.alwaysTrue(), Optional.empty()); + assertTrue(tableLayouts.size() == 1, "Expected exactly one layout."); + ConnectorTableLayout tableLayout = tableLayouts.get(0).getTableLayout(); + ConnectorTableLayoutHandle tableLayoutHandle = tableLayout.getHandle(); + assertTrue(tableLayoutHandle instanceof MemoryTableLayoutHandle); + assertTrue(((MemoryTableLayoutHandle) tableLayoutHandle).getDataFragments().isEmpty(), "Data fragments should be empty"); + + metadata.finishCreateTable(SESSION, table, ImmutableList.of()); + } + + @Test + public void testCreateSchema() + { + assertEquals(metadata.listSchemaNames(SESSION), ImmutableList.of("default")); + metadata.createSchema(SESSION, "test", ImmutableMap.of()); + assertEquals(metadata.listSchemaNames(SESSION), ImmutableList.of("default", "test")); + assertEquals(metadata.listTables(SESSION, "test"), ImmutableList.of()); + + SchemaTableName tableName = new SchemaTableName("test", "first_table"); + metadata.createTable( + SESSION, + new ConnectorTableMetadata( + tableName, + ImmutableList.of(), + ImmutableMap.of())); + + assertEquals(metadata.listTables(SESSION, null), ImmutableList.of(tableName)); + assertEquals(metadata.listTables(SESSION, "test"), ImmutableList.of(tableName)); + assertEquals(metadata.listTables(SESSION, "default"), ImmutableList.of()); + } + + private void assertNoTables() { assertEquals(metadata.listTables(SESSION, null), ImmutableList.of(), "No table was expected"); } diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java index 7e5b827b48672..f65cf466c9310 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java @@ -38,6 +38,7 @@ public class TestMemoryPagesStore { public static final ConnectorSession SESSION = new TestingConnectorSession(ImmutableList.of()); + private static final int POSITIONS_PER_PAGE = 0; private MemoryPagesStore pagesStore; private MemoryPageSinkProvider pageSinkProvider; @@ -46,14 +47,14 @@ public class TestMemoryPagesStore public void setUp() { pagesStore = new MemoryPagesStore(new MemoryConfig().setMaxDataPerNode(new DataSize(1, DataSize.Unit.MEGABYTE))); - pageSinkProvider = new MemoryPageSinkProvider(pagesStore); + pageSinkProvider = new MemoryPageSinkProvider(pagesStore, HostAddress.fromString("localhost:8080")); } @Test public void testCreateEmptyTable() { createTable(0L, 0L); - assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0)), ImmutableList.of()); + assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0), ImmutableList.of()); } @Test @@ -61,19 +62,28 @@ public void testInsertPage() { createTable(0L, 0L); insertToTable(0L, 0L); - assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0)).size(), 1); + assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), POSITIONS_PER_PAGE).size(), 1); + } + + @Test + public void testInsertPageWithoutCreate() + { + insertToTable(0L, 0L); + assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), POSITIONS_PER_PAGE).size(), 1); } @Test(expectedExceptions = PrestoException.class) public void testReadFromUnknownTable() { - pagesStore.getPages(0L, 0, 1, ImmutableList.of(0)); + pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0); } @Test(expectedExceptions = PrestoException.class) - public void testWriteToUnknownTable() + public void testTryToReadFromEmptyTable() { - insertToTable(0L, 0L); + createTable(0L, 0L); + assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0), ImmutableList.of()); + pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 42); } @Test @@ -139,8 +149,8 @@ private static ConnectorOutputTableHandle createMemoryOutputTableHandle(long tab "test", "schema", format("table_%d", tableId), - tableId, ImmutableList.of(), - ImmutableList.of(HostAddress.fromString("localhost:8080"))), + tableId, + ImmutableList.of()), ImmutableSet.copyOf(activeTableIds)); } @@ -152,21 +162,20 @@ private static ConnectorInsertTableHandle createMemoryInsertTableHandle(long tab "schema", format("table_%d", tableId), tableId, - ImmutableList.of(), - ImmutableList.of(HostAddress.fromString("localhost:8080"))), + ImmutableList.of()), ImmutableSet.copyOf(activeTableIds)); } private static Page createPage() { - BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(1); + BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(POSITIONS_PER_PAGE); BIGINT.writeLong(blockBuilder, 42L); return new Page(0, blockBuilder.build()); } private static Page createOneMegaBytePage() { - BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(1); + BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(POSITIONS_PER_PAGE); while (blockBuilder.getRetainedSizeInBytes() < 1024 * 1024) { BIGINT.writeLong(blockBuilder, 42L); } diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemorySmoke.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemorySmoke.java index 3f1776ff831ad..1ee2d351df45d 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemorySmoke.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemorySmoke.java @@ -13,23 +13,21 @@ */ package com.facebook.presto.plugin.memory; -import com.facebook.presto.Session; import com.facebook.presto.metadata.QualifiedObjectName; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.QueryRunner; -import com.google.common.collect.Iterables; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; import java.sql.SQLException; import java.util.List; +import static com.facebook.presto.plugin.memory.MemoryQueryRunner.CATALOG; import static com.facebook.presto.plugin.memory.MemoryQueryRunner.createQueryRunner; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static java.lang.String.format; import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; @Test(singleThreaded = true) public class TestMemorySmoke @@ -44,98 +42,122 @@ public void setUp() } @Test - public void createAndDropTable() + public void testCreateAndDropTable() throws SQLException { int tablesBeforeCreate = listMemoryTables().size(); - queryRunner.execute("CREATE TABLE test as SELECT * FROM tpch.tiny.nation"); + queryRunner.execute("CREATE TABLE test AS SELECT * FROM tpch.tiny.nation"); assertEquals(listMemoryTables().size(), tablesBeforeCreate + 1); queryRunner.execute(format("DROP TABLE test")); assertEquals(listMemoryTables().size(), tablesBeforeCreate); } - @Test - public void createTableWhenTableIsAlreadyCreated() + // it has to be RuntimeException as FailureInfo$FailureException is private + @Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = "line 1:1: Destination table 'memory.default.nation' already exists") + public void testCreateTableWhenTableIsAlreadyCreated() throws SQLException { - String createTableSql = "CREATE TABLE nation as SELECT * FROM tpch.tiny.nation"; - try { - queryRunner.execute(createTableSql); - fail("Expected exception to be thrown here!"); - } - catch (RuntimeException ex) { // it has to RuntimeException as FailureInfo$FailureException is private - assertTrue(ex.getMessage().equals("line 1:1: Destination table 'memory.default.nation' already exists")); - } + String createTableSql = "CREATE TABLE nation AS SELECT * FROM tpch.tiny.nation"; + queryRunner.execute(createTableSql); } @Test - public void select() + public void testSelect() throws SQLException { - queryRunner.execute("CREATE TABLE test_select as SELECT * FROM tpch.tiny.nation"); + queryRunner.execute("CREATE TABLE test_select AS SELECT * FROM tpch.tiny.nation"); - assertThatQueryReturnsSameValueAs("SELECT * FROM test_select ORDER BY nationkey", "SELECT * FROM tpch.tiny.nation ORDER BY nationkey"); + assertQuery("SELECT * FROM test_select ORDER BY nationkey", "SELECT * FROM tpch.tiny.nation ORDER BY nationkey"); - assertThatQueryReturnsValue("INSERT INTO test_select SELECT * FROM tpch.tiny.nation", 25L); + assertQueryResult("INSERT INTO test_select SELECT * FROM tpch.tiny.nation", 25L); - assertThatQueryReturnsValue("INSERT INTO test_select SELECT * FROM tpch.tiny.nation", 25L); + assertQueryResult("INSERT INTO test_select SELECT * FROM tpch.tiny.nation", 25L); - assertThatQueryReturnsValue("SELECT count(*) FROM test_select", 75L); + assertQueryResult("SELECT count(*) FROM test_select", 75L); } @Test - public void selectFromEmptyTable() + public void testCreateTableWithNoData() throws SQLException { - queryRunner.execute("CREATE TABLE test_select_empty as SELECT * FROM tpch.tiny.nation WHERE nationkey > 1000"); - - assertThatQueryReturnsValue("SELECT count(*) FROM test_select_empty", 0L); + queryRunner.execute("CREATE TABLE test_empty (a BIGINT)"); + assertQueryResult("SELECT count(*) FROM test_empty", 0L); + assertQueryResult("INSERT INTO test_empty SELECT nationkey FROM tpch.tiny.nation", 25L); + assertQueryResult("SELECT count(*) FROM test_empty", 25L); } @Test - public void selectSingleRow() + public void testCreateFilteredOutTable() + throws SQLException { - assertThatQueryReturnsSameValueAs("SELECT * FROM nation WHERE nationkey = 1", "SELECT * FROM tpch.tiny.nation WHERE nationkey = 1"); + queryRunner.execute("CREATE TABLE filtered_out AS SELECT nationkey FROM tpch.tiny.nation WHERE nationkey < 0"); + assertQueryResult("SELECT count(*) FROM filtered_out", 0L); + assertQueryResult("INSERT INTO filtered_out SELECT nationkey FROM tpch.tiny.nation", 25L); + assertQueryResult("SELECT count(*) FROM filtered_out", 25L); } @Test - public void selectColumnsSubset() + public void testSelectFromEmptyTable() throws SQLException { - assertThatQueryReturnsSameValueAs("SELECT nationkey, regionkey FROM nation ORDER BY nationkey", "SELECT nationkey, regionkey FROM tpch.tiny.nation ORDER BY nationkey"); + queryRunner.execute("CREATE TABLE test_select_empty AS SELECT * FROM tpch.tiny.nation WHERE nationkey > 1000"); + + assertQueryResult("SELECT count(*) FROM test_select_empty", 0L); } - private List listMemoryTables() + @Test + public void testSelectSingleRow() { - return queryRunner.listTables(queryRunner.getDefaultSession(), "memory", "default"); + assertQuery("SELECT * FROM nation WHERE nationkey = 1", "SELECT * FROM tpch.tiny.nation WHERE nationkey = 1"); } - private void assertThatQueryReturnsValue(String sql, Object expected) + @Test + public void testSelectColumnsSubset() + throws SQLException { - assertThatQueryReturnsValue(sql, expected, null); + assertQuery("SELECT nationkey, regionkey FROM nation ORDER BY nationkey", "SELECT nationkey, regionkey FROM tpch.tiny.nation ORDER BY nationkey"); } - private void assertThatQueryReturnsValue(String sql, Object expected, Session session) + @Test + public void testCreateTableInNonDefaultSchema() + { + queryRunner.execute(format("CREATE SCHEMA %s.schema1", CATALOG)); + queryRunner.execute(format("CREATE SCHEMA %s.schema2", CATALOG)); + + assertQueryResult(format("SHOW SCHEMAS FROM %s", CATALOG), "default", "information_schema", "schema1", "schema2"); + + queryRunner.execute(format("CREATE TABLE %s.schema1.nation AS SELECT * FROM tpch.tiny.nation WHERE nationkey %% 2 = 0", CATALOG)); + queryRunner.execute(format("CREATE TABLE %s.schema2.nation AS SELECT * FROM tpch.tiny.nation WHERE nationkey %% 2 = 1", CATALOG)); + + assertQueryResult(format("SELECT count(*) FROM %s.schema1.nation", CATALOG), 13L); + assertQueryResult(format("SELECT count(*) FROM %s.schema2.nation", CATALOG), 12L); + } + + private List listMemoryTables() { - MaterializedResult rows = session == null ? queryRunner.execute(sql) : queryRunner.execute(session, sql); - MaterializedRow materializedRow = Iterables.getOnlyElement(rows); - int fieldCount = materializedRow.getFieldCount(); - assertTrue(fieldCount == 1, format("Expected only one column, but got '%d'", fieldCount)); - Object value = materializedRow.getField(0); - assertEquals(value, expected); - assertTrue(Iterables.getOnlyElement(rows).getFieldCount() == 1); + return queryRunner.listTables(queryRunner.getDefaultSession(), "memory", "default"); } - private void assertThatQueryReturnsSameValueAs(String sql, String compareSql) + private void assertQueryResult(String sql, Object... expected) { - assertThatQueryReturnsSameValueAs(sql, compareSql, null); + MaterializedResult rows = queryRunner.execute(sql); + assertEquals(rows.getRowCount(), expected.length); + + for (int i = 0; i < expected.length; i++) { + MaterializedRow materializedRow = rows.getMaterializedRows().get(i); + int fieldCount = materializedRow.getFieldCount(); + assertTrue(fieldCount == 1, format("Expected only one column, but got '%d'", fieldCount)); + Object value = materializedRow.getField(0); + assertEquals(value, expected[i]); + assertTrue(materializedRow.getFieldCount() == 1); + } } - private void assertThatQueryReturnsSameValueAs(String sql, String compareSql, Session session) + private void assertQuery(String sql, String expected) { - MaterializedResult rows = session == null ? queryRunner.execute(sql) : queryRunner.execute(session, sql); - MaterializedResult expectedRows = session == null ? queryRunner.execute(compareSql) : queryRunner.execute(session, compareSql); + MaterializedResult rows = queryRunner.execute(sql); + MaterializedResult expectedRows = queryRunner.execute(expected); assertEquals(rows, expectedRows); } diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java new file mode 100644 index 0000000000000..3bec61b8cddd7 --- /dev/null +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.memory; + +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.tests.DistributedQueryRunner; +import io.airlift.units.Duration; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static com.facebook.presto.plugin.memory.MemoryQueryRunner.createQueryRunner; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static io.airlift.testing.Assertions.assertLessThan; +import static io.airlift.units.Duration.nanosSince; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.fail; + +@Test(singleThreaded = true) +public class TestMemoryWorkerCrash +{ + private DistributedQueryRunner queryRunner; + + @BeforeMethod + public void setUp() + throws Exception + { + queryRunner = createQueryRunner(); + } + + @Test + public void tableAccessAfterWorkerCrash() + throws Exception + { + queryRunner.execute("CREATE TABLE test_nation as SELECT * FROM tpch.tiny.nation"); + assertQuery("SELECT * FROM test_nation ORDER BY nationkey", "SELECT * FROM tpch.tiny.nation ORDER BY nationkey"); + closeWorker(); + assertFails("SELECT * FROM test_nation ORDER BY nationkey", "No nodes available to run query"); + queryRunner.execute("INSERT INTO test_nation SELECT * FROM tpch.tiny.nation"); + assertFails("SELECT * FROM test_nation ORDER BY nationkey", "No nodes available to run query"); + + queryRunner.execute("CREATE TABLE test_region as SELECT * FROM tpch.tiny.region"); + assertQuery("SELECT * FROM test_region ORDER BY regionkey", "SELECT * FROM tpch.tiny.region ORDER BY regionkey"); + } + + private void closeWorker() + throws Exception + { + int nodeCount = queryRunner.getNodeCount(); + TestingPrestoServer worker = queryRunner.getServers().stream() + .filter(server -> !server.isCoordinator()) + .findAny() + .orElseThrow(() -> new IllegalStateException("No worker nodes")); + worker.close(); + waitForNodes(nodeCount - 1); + } + + private void assertQuery(String sql, String expected) + { + MaterializedResult rows = queryRunner.execute(sql); + MaterializedResult expectedRows = queryRunner.execute(expected); + + assertEquals(rows, expectedRows); + } + + private void assertFails(String sql, String expectedMessage) + { + try { + queryRunner.execute(sql); + } + catch (RuntimeException ex) { + // pass + assertEquals(ex.getMessage(), expectedMessage); + return; + } + fail("Query should fail"); + } + + private void waitForNodes(int numberOfNodes) + throws InterruptedException + { + long start = System.nanoTime(); + while (queryRunner.getCoordinator().refreshNodes().getActiveNodes().size() < numberOfNodes) { + assertLessThan(nanosSince(start), new Duration(10, SECONDS)); + MILLISECONDS.sleep(10); + } + } +} diff --git a/presto-ml/pom.xml b/presto-ml/pom.xml index c7f00cd32feba..7ee92d9d891d3 100644 --- a/presto-ml/pom.xml +++ b/presto-ml/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-ml diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/MLFeaturesFunctions.java b/presto-ml/src/main/java/com/facebook/presto/ml/MLFeaturesFunctions.java new file mode 100644 index 0000000000000..aa69b38b14cb8 --- /dev/null +++ b/presto-ml/src/main/java/com/facebook/presto/ml/MLFeaturesFunctions.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.ml; + +import com.facebook.presto.spi.PageBuilder; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.DoubleType; +import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.collect.ImmutableList; +import com.google.common.hash.HashCode; + +public final class MLFeaturesFunctions +{ + private static final Cache MODEL_CACHE = CacheBuilder.newBuilder().maximumSize(5).build(); + private static final String MAP_BIGINT_DOUBLE = "map(bigint,double)"; + + private final PageBuilder pageBuilder; + + public MLFeaturesFunctions(@TypeParameter("map(bigint,double)") Type mapType) + { + pageBuilder = new PageBuilder(ImmutableList.of(mapType)); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1) + { + return featuresHelper(f1); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2) + { + return featuresHelper(f1, f2); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3) + { + return featuresHelper(f1, f2, f3); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4) + { + return featuresHelper(f1, f2, f3, f4); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5) + { + return featuresHelper(f1, f2, f3, f4, f5); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6) + { + return featuresHelper(f1, f2, f3, f4, f5, f6); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9); + } + + @ScalarFunction + @SqlType(MAP_BIGINT_DOUBLE) + public Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9, @SqlType(StandardTypes.DOUBLE) double f10) + { + return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9, f10); + } + + private Block featuresHelper(double... features) + { + BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); + + for (int i = 0; i < features.length; i++) { + BigintType.BIGINT.writeLong(blockBuilder, i); + DoubleType.DOUBLE.writeDouble(blockBuilder, features[i]); + } + + mapBlockBuilder.closeEntry(); + return mapBlockBuilder.getObject(mapBlockBuilder.getPositionCount() - 1, Block.class); + } +} diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java b/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java index 518534b6b1fdc..e14f6dad3ba5a 100644 --- a/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java +++ b/presto-ml/src/main/java/com/facebook/presto/ml/MLFunctions.java @@ -15,17 +15,11 @@ import com.facebook.presto.ml.type.RegressorType; import com.facebook.presto.spi.block.Block; -import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.type.BigintType; -import com.facebook.presto.spi.type.DoubleType; import com.facebook.presto.spi.type.StandardTypes; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -import com.google.common.collect.ImmutableList; import com.google.common.hash.HashCode; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -89,86 +83,4 @@ private static Model getOrLoadModel(Slice slice) return model; } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1) - { - return featuresHelper(f1); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2) - { - return featuresHelper(f1, f2); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3) - { - return featuresHelper(f1, f2, f3); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4) - { - return featuresHelper(f1, f2, f3, f4); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5) - { - return featuresHelper(f1, f2, f3, f4, f5); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6) - { - return featuresHelper(f1, f2, f3, f4, f5, f6); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9); - } - - @ScalarFunction - @SqlType(MAP_BIGINT_DOUBLE) - public static Block features(@SqlType(StandardTypes.DOUBLE) double f1, @SqlType(StandardTypes.DOUBLE) double f2, @SqlType(StandardTypes.DOUBLE) double f3, @SqlType(StandardTypes.DOUBLE) double f4, @SqlType(StandardTypes.DOUBLE) double f5, @SqlType(StandardTypes.DOUBLE) double f6, @SqlType(StandardTypes.DOUBLE) double f7, @SqlType(StandardTypes.DOUBLE) double f8, @SqlType(StandardTypes.DOUBLE) double f9, @SqlType(StandardTypes.DOUBLE) double f10) - { - return featuresHelper(f1, f2, f3, f4, f5, f6, f7, f8, f9, f10); - } - - private static Block featuresHelper(double... features) - { - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(BigintType.BIGINT, DoubleType.DOUBLE), new BlockBuilderStatus(), features.length); - - for (int i = 0; i < features.length; i++) { - BigintType.BIGINT.writeLong(blockBuilder, i); - DoubleType.DOUBLE.writeDouble(blockBuilder, features[i]); - } - - return blockBuilder.build(); - } } diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java b/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java index 7cbfa86aef7cc..c7d96bcf374c5 100644 --- a/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java +++ b/presto-ml/src/main/java/com/facebook/presto/ml/MLPlugin.java @@ -52,6 +52,7 @@ public Set> getFunctions() .add(LearnLibSvmRegressorAggregation.class) .add(EvaluateClassifierPredictionsAggregation.class) .add(MLFunctions.class) + .add(MLFeaturesFunctions.class) .build(); } } diff --git a/presto-ml/src/main/java/com/facebook/presto/ml/type/ClassifierParametricType.java b/presto-ml/src/main/java/com/facebook/presto/ml/type/ClassifierParametricType.java index 234af99ba1115..b50f61ab0f780 100644 --- a/presto-ml/src/main/java/com/facebook/presto/ml/type/ClassifierParametricType.java +++ b/presto-ml/src/main/java/com/facebook/presto/ml/type/ClassifierParametricType.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.type.ParameterKind; import com.facebook.presto.spi.type.ParametricType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeParameter; import java.util.List; @@ -34,7 +35,7 @@ public String getName() } @Override - public Type createType(List parameters) + public Type createType(TypeManager typeManager, List parameters) { checkArgument(parameters.size() == 1, "Expected only one type, got %s", parameters); checkArgument( diff --git a/presto-ml/src/test/java/com/facebook/presto/ml/TestLearnAggregations.java b/presto-ml/src/test/java/com/facebook/presto/ml/TestLearnAggregations.java index 8498404186ccc..4009772c2cc3c 100644 --- a/presto-ml/src/test/java/com/facebook/presto/ml/TestLearnAggregations.java +++ b/presto-ml/src/test/java/com/facebook/presto/ml/TestLearnAggregations.java @@ -14,7 +14,9 @@ package com.facebook.presto.ml; import com.facebook.presto.RowPageBuilder; +import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.ml.type.ClassifierParametricType; import com.facebook.presto.ml.type.ClassifierType; import com.facebook.presto.ml.type.ModelType; @@ -34,6 +36,7 @@ import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.type.TypeRegistry; import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.collect.ImmutableList; @@ -58,6 +61,10 @@ public class TestLearnAggregations typeRegistry.addParametricType(new ClassifierParametricType()); typeRegistry.addType(ModelType.MODEL); typeRegistry.addType(RegressorType.REGRESSOR); + + // associate typeRegistry with a function registry + new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); + typeManager = typeRegistry; } diff --git a/presto-mongodb/pom.xml b/presto-mongodb/pom.xml index c4a0693a958b8..4fab26610799d 100644 --- a/presto-mongodb/pom.xml +++ b/presto-mongodb/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-mongodb diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java index 43ca7cbe97660..f29891953d217 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSource.java @@ -214,18 +214,18 @@ private void writeBlock(BlockBuilder output, Type type, Object value) { if (isArrayType(type)) { if (value instanceof List) { - BlockBuilder builder = createParametersBlockBuilder(type, ((List) value).size()); + BlockBuilder builder = output.beginBlockEntry(); ((List) value).forEach(element -> appendTo(type.getTypeParameters().get(0), element, builder)); - type.writeObject(output, builder.build()); + output.closeEntry(); return; } } else if (isMapType(type)) { if (value instanceof List) { - BlockBuilder builder = createParametersBlockBuilder(type, ((List) value).size()); + BlockBuilder builder = output.beginBlockEntry(); for (Object element : (List) value) { if (!(element instanceof Map)) { continue; @@ -238,14 +238,14 @@ else if (isMapType(type)) { } } - type.writeObject(output, builder.build()); + output.closeEntry(); return; } } else if (isRowType(type)) { if (value instanceof Map) { Map mapValue = (Map) value; - BlockBuilder builder = createParametersBlockBuilder(type, mapValue.size()); + BlockBuilder builder = output.beginBlockEntry(); List fieldNames = type.getTypeSignature().getParameters().stream() .map(TypeSignatureParameter::getNamedTypeSignature) .map(NamedTypeSignature::getName) @@ -254,12 +254,12 @@ else if (isRowType(type)) { for (int index = 0; index < type.getTypeParameters().size(); index++) { appendTo(type.getTypeParameters().get(index), mapValue.get(fieldNames.get(index).toString()), builder); } - type.writeObject(output, builder.build()); + output.closeEntry(); return; } else if (value instanceof List) { List listValue = (List) value; - BlockBuilder builder = createParametersBlockBuilder(type, listValue.size()); + BlockBuilder builder = output.beginBlockEntry(); for (int index = 0; index < type.getTypeParameters().size(); index++) { if (index < listValue.size()) { appendTo(type.getTypeParameters().get(index), listValue.get(index), builder); @@ -268,7 +268,7 @@ else if (value instanceof List) { builder.appendNull(); } } - type.writeObject(output, builder.build()); + output.closeEntry(); return; } } diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java index 0b1f07785d274..0a91a95367ed9 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java @@ -145,9 +145,9 @@ public Set getAllTables(String schema) ImmutableSet.Builder builder = ImmutableSet.builder(); builder.addAll(ImmutableList.copyOf(client.getDatabase(schema).listCollectionNames()).stream() - .filter(name -> !name.equals(schemaCollection)) - .filter(name -> !SYSTEM_TABLES.contains(name)) - .collect(toSet())); + .filter(name -> !name.equals(schemaCollection)) + .filter(name -> !SYSTEM_TABLES.contains(name)) + .collect(toSet())); builder.addAll(getTableMetadataNames(schema)); return builder.build(); @@ -270,6 +270,7 @@ static Document buildQuery(TupleDomain tupleDomain) private static Document buildPredicate(MongoColumnHandle column, Domain domain) { String name = column.getName(); + Type type = column.getType(); if (domain.getValues().isNone() && domain.isNullAllowed()) { return documentOf(name, isNullPredicate()); } @@ -287,30 +288,30 @@ private static Document buildPredicate(MongoColumnHandle column, Domain domain) Document rangeConjuncts = new Document(); if (!range.getLow().isLowerUnbounded()) { switch (range.getLow().getBound()) { - case ABOVE: - rangeConjuncts.put(GT_OP, range.getLow().getValue()); - break; - case EXACTLY: - rangeConjuncts.put(GTE_OP, range.getLow().getValue()); - break; - case BELOW: - throw new IllegalArgumentException("Low Marker should never use BELOW bound: " + range); - default: - throw new AssertionError("Unhandled bound: " + range.getLow().getBound()); + case ABOVE: + rangeConjuncts.put(GT_OP, range.getLow().getValue()); + break; + case EXACTLY: + rangeConjuncts.put(GTE_OP, range.getLow().getValue()); + break; + case BELOW: + throw new IllegalArgumentException("Low Marker should never use BELOW bound: " + range); + default: + throw new AssertionError("Unhandled bound: " + range.getLow().getBound()); } } if (!range.getHigh().isUpperUnbounded()) { switch (range.getHigh().getBound()) { - case ABOVE: - throw new IllegalArgumentException("High Marker should never use ABOVE bound: " + range); - case EXACTLY: - rangeConjuncts.put(LTE_OP, range.getHigh().getValue()); - break; - case BELOW: - rangeConjuncts.put(LT_OP, range.getHigh().getValue()); - break; - default: - throw new AssertionError("Unhandled bound: " + range.getHigh().getBound()); + case ABOVE: + throw new IllegalArgumentException("High Marker should never use ABOVE bound: " + range); + case EXACTLY: + rangeConjuncts.put(LTE_OP, range.getHigh().getValue()); + break; + case BELOW: + rangeConjuncts.put(LT_OP, range.getHigh().getValue()); + break; + default: + throw new AssertionError("Unhandled bound: " + range.getHigh().getBound()); } } // If rangeConjuncts is null, then the range was ALL, which should already have been checked for @@ -321,11 +322,11 @@ private static Document buildPredicate(MongoColumnHandle column, Domain domain) // Add back all of the possible single values either as an equality or an IN predicate if (singleValues.size() == 1) { - disjuncts.add(documentOf(EQ_OP, translateValue(singleValues.get(0)))); + disjuncts.add(documentOf(EQ_OP, translateValue(singleValues.get(0), type))); } else if (singleValues.size() > 1) { disjuncts.add(documentOf(IN_OP, singleValues.stream() - .map(MongoSession::translateValue) + .map(value -> translateValue(value, type)) .collect(toList()))); } @@ -334,14 +335,19 @@ else if (singleValues.size() > 1) { } return orPredicate(disjuncts.stream() - .map(disjunct -> new Document(name, disjunct)) - .collect(toList())); + .map(disjunct -> new Document(name, disjunct)) + .collect(toList())); } - private static Object translateValue(Object source) + private static Object translateValue(Object source, Type type) { if (source instanceof Slice) { - return ((Slice) source).toStringUtf8(); + if (type instanceof ObjectIdType) { + return new ObjectId(((Slice) source).getBytes()); + } + else { + return ((Slice) source).toStringUtf8(); + } } return source; @@ -442,8 +448,8 @@ private void createTableMetadata(SchemaTableName schemaTableName, List signatures = subTypes.stream().map(t -> t.get()).collect(toSet()); if (signatures.size() == 1) { typeSignature = new TypeSignature(StandardTypes.ARRAY, signatures.stream() - .map(s -> TypeSignatureParameter.of(s)) - .collect(Collectors.toList())); + .map(s -> TypeSignatureParameter.of(s)) + .collect(Collectors.toList())); } else { // TODO: presto cli doesn't handle empty field name row type yet diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java index 78b62b70a6f1f..cce6774907494 100644 --- a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java @@ -70,7 +70,8 @@ public void createTableWithEveryType() ", 3.14 _double" + ", true _boolean" + ", DATE '1980-05-07' _date" + - ", TIMESTAMP '1980-05-07 11:22:33.456' _timestamp"; + ", TIMESTAMP '1980-05-07 11:22:33.456' _timestamp" + + ", ObjectId('ffffffffffffffffffffffff') _objectid"; assertUpdate(query, 1); @@ -161,6 +162,14 @@ public void testCollectionNameContainsDots() assertUpdate("DROP TABLE \"tmp.dot1\""); } + @Test + public void testObjectIds() + throws Exception + { + assertUpdate("CREATE TABLE tmp_objectid AS SELECT ObjectId('ffffffffffffffffffffffff') AS id", 1); + assertOneNotNullResult("SELECT id FROM tmp_objectid WHERE id = ObjectId('ffffffffffffffffffffffff')"); + } + private void assertOneNotNullResult(String query) { MaterializedResult results = getQueryRunner().execute(getSession(), query).toJdbcTypes(); diff --git a/presto-mysql/pom.xml b/presto-mysql/pom.xml index 04bc4150ecbda..2f51f079565c0 100644 --- a/presto-mysql/pom.xml +++ b/presto-mysql/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-mysql diff --git a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClientModule.java b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClientModule.java index 00c601d26405c..0ae1425dc4d3f 100644 --- a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClientModule.java +++ b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClientModule.java @@ -16,19 +16,37 @@ import com.facebook.presto.plugin.jdbc.BaseJdbcConfig; import com.facebook.presto.plugin.jdbc.JdbcClient; import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Scopes; +import com.mysql.jdbc.Driver; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import java.sql.SQLException; +import java.util.Properties; + +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.configuration.ConfigBinder.configBinder; public class MySqlClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + protected void setup(Binder binder) { binder.bind(JdbcClient.class).to(MySqlClient.class).in(Scopes.SINGLETON); - configBinder(binder).bindConfig(BaseJdbcConfig.class); + ensureCatalogIsEmpty(buildConfigObject(BaseJdbcConfig.class).getConnectionUrl()); configBinder(binder).bindConfig(MySqlConfig.class); } + + private static void ensureCatalogIsEmpty(String connectionUrl) + { + try { + Driver driver = new Driver(); + Properties urlProperties = driver.parseURL(connectionUrl, null); + checkArgument(urlProperties != null, "Invalid JDBC URL for MySQL connector"); + checkArgument(driver.database(urlProperties) == null, "Database (catalog) must not be specified in JDBC URL for MySQL connector"); + } + catch (SQLException e) { + throw new RuntimeException(e); + } + } } diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlPlugin.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlPlugin.java index 275e7bb126d74..0092db2b62da8 100644 --- a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlPlugin.java +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlPlugin.java @@ -29,6 +29,6 @@ public void testCreateConnector() { Plugin plugin = new MySqlPlugin(); ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); - factory.create("test", ImmutableMap.of("connection-url", "test"), new TestingConnectorContext()); + factory.create("test", ImmutableMap.of("connection-url", "jdbc:mysql://test"), new TestingConnectorContext()); } } diff --git a/presto-orc/pom.xml b/presto-orc/pom.xml index 79d2ce096273e..edd812e7d08bb 100644 --- a/presto-orc/pom.xml +++ b/presto-orc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-orc @@ -112,7 +112,7 @@ com.facebook.presto.hadoop - hadoop-cdh4 + hadoop-apache2 test diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/AbstractOrcDataSource.java b/presto-orc/src/main/java/com/facebook/presto/orc/AbstractOrcDataSource.java index 6b99305fbfce6..25e0ebaf03adc 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/AbstractOrcDataSource.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/AbstractOrcDataSource.java @@ -37,7 +37,7 @@ public abstract class AbstractOrcDataSource implements OrcDataSource { - private final String name; + private final OrcDataSourceId id; private final long size; private final DataSize maxMergeDistance; private final DataSize maxBufferSize; @@ -45,9 +45,9 @@ public abstract class AbstractOrcDataSource private long readTimeNanos; private long readBytes; - public AbstractOrcDataSource(String name, long size, DataSize maxMergeDistance, DataSize maxBufferSize, DataSize streamBufferSize) + public AbstractOrcDataSource(OrcDataSourceId id, long size, DataSize maxMergeDistance, DataSize maxBufferSize, DataSize streamBufferSize) { - this.name = requireNonNull(name, "name is null"); + this.id = requireNonNull(id, "id is null"); this.size = size; checkArgument(size >= 0, "size is negative"); @@ -60,6 +60,12 @@ public AbstractOrcDataSource(String name, long size, DataSize maxMergeDistance, protected abstract void readInternal(long position, byte[] buffer, int bufferOffset, int bufferLength) throws IOException; + @Override + public OrcDataSourceId getId() + { + return id; + } + @Override public final long getReadBytes() { @@ -177,7 +183,7 @@ private Map readLargeDiskRanges(Map @Override public final String toString() { - return name; + return id.toString(); } private class HdfsSliceLoader diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/CachingOrcDataSource.java b/presto-orc/src/main/java/com/facebook/presto/orc/CachingOrcDataSource.java index 1bc309036de1b..5247532123617 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/CachingOrcDataSource.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/CachingOrcDataSource.java @@ -41,6 +41,12 @@ public CachingOrcDataSource(OrcDataSource dataSource, RegionFinder regionFinder) this.cache = new byte[0]; } + @Override + public OrcDataSourceId getId() + { + return dataSource.getId(); + } + @Override public long getReadBytes() { diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/FileOrcDataSource.java b/presto-orc/src/main/java/com/facebook/presto/orc/FileOrcDataSource.java index e8b6baef20512..03cd9b6a30425 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/FileOrcDataSource.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/FileOrcDataSource.java @@ -28,7 +28,7 @@ public class FileOrcDataSource public FileOrcDataSource(File path, DataSize maxMergeDistance, DataSize maxReadSize, DataSize streamBufferSize) throws FileNotFoundException { - super(path.getPath(), path.length(), maxMergeDistance, maxReadSize, streamBufferSize); + super(new OrcDataSourceId(path.getPath()), path.length(), maxMergeDistance, maxReadSize, streamBufferSize); this.input = new RandomAccessFile(path, "r"); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcCorruptionException.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcCorruptionException.java index 183a138f8e721..6d4adfb290d77 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcCorruptionException.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcCorruptionException.java @@ -20,18 +20,23 @@ public class OrcCorruptionException extends IOException { - public OrcCorruptionException(String message) + public OrcCorruptionException(OrcDataSourceId orcDataSourceId, String message) { super(message); } - public OrcCorruptionException(String messageFormat, Object... args) + public OrcCorruptionException(OrcDataSourceId orcDataSourceId, String messageFormat, Object... args) { - super(format(messageFormat, args)); + super(formatMessage(orcDataSourceId, messageFormat, args)); } - public OrcCorruptionException(Throwable cause, String messageFormat, Object... args) + public OrcCorruptionException(Throwable cause, OrcDataSourceId orcDataSourceId, String messageFormat, Object... args) { - super(format(messageFormat, args), cause); + super(formatMessage(orcDataSourceId, messageFormat, args), cause); + } + + private static String formatMessage(OrcDataSourceId orcDataSourceId, String messageFormat, Object[] args) + { + return "Malformed ORC file. " + format(messageFormat, args) + " [" + orcDataSourceId + "]"; } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSource.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSource.java index ae4546332a669..acef99082c650 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSource.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSource.java @@ -22,6 +22,8 @@ public interface OrcDataSource extends Closeable { + OrcDataSourceId getId(); + long getReadBytes(); long getReadTimeNanos(); diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/UserAgentRequestFilter.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSourceId.java similarity index 50% rename from presto-jdbc/src/main/java/com/facebook/presto/jdbc/UserAgentRequestFilter.java rename to presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSourceId.java index 7b0be8c475c12..3dbfd3e98030c 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/UserAgentRequestFilter.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcDataSourceId.java @@ -11,30 +11,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.jdbc; +package com.facebook.presto.orc; -import com.google.common.net.HttpHeaders; -import io.airlift.http.client.HttpRequestFilter; -import io.airlift.http.client.Request; +import java.util.Objects; -import static io.airlift.http.client.Request.Builder.fromRequest; import static java.util.Objects.requireNonNull; -class UserAgentRequestFilter - implements HttpRequestFilter +public final class OrcDataSourceId { - private final String userAgent; + private final String id; - public UserAgentRequestFilter(String userAgent) + public OrcDataSourceId(String id) { - this.userAgent = requireNonNull(userAgent, "userAgent is null"); + this.id = requireNonNull(id, "id is null"); } @Override - public Request filterRequest(Request request) + public boolean equals(Object o) { - return fromRequest(request) - .addHeader(HttpHeaders.USER_AGENT, userAgent) - .build(); + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + OrcDataSourceId that = (OrcDataSourceId) o; + return Objects.equals(id, that.id); + } + + @Override + public int hashCode() + { + return Objects.hash(id); + } + + @Override + public String toString() + { + return id; } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java index 4e281ef712755..8adba7b45e435 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcReader.java @@ -15,6 +15,7 @@ import com.facebook.presto.orc.memory.AbstractAggregatedMemoryContext; import com.facebook.presto.orc.memory.AggregatedMemoryContext; +import com.facebook.presto.orc.metadata.ExceptionWrappingMetadataReader; import com.facebook.presto.orc.metadata.Footer; import com.facebook.presto.orc.metadata.Metadata; import com.facebook.presto.orc.metadata.MetadataReader; @@ -52,9 +53,10 @@ public class OrcReader private static final int EXPECTED_FOOTER_SIZE = 16 * 1024; private final OrcDataSource orcDataSource; - private final MetadataReader metadataReader; + private final ExceptionWrappingMetadataReader metadataReader; private final DataSize maxMergeDistance; private final DataSize maxReadSize; + private final DataSize maxBlockSize; private final HiveWriterVersion hiveWriterVersion; private final int bufferSize; private final Footer footer; @@ -62,14 +64,15 @@ public class OrcReader private Optional decompressor = Optional.empty(); // This is based on the Apache Hive ORC code - public OrcReader(OrcDataSource orcDataSource, MetadataReader metadataReader, DataSize maxMergeDistance, DataSize maxReadSize) + public OrcReader(OrcDataSource orcDataSource, MetadataReader delegate, DataSize maxMergeDistance, DataSize maxReadSize, DataSize maxBlockSize) throws IOException { orcDataSource = wrapWithCacheIfTiny(requireNonNull(orcDataSource, "orcDataSource is null"), maxMergeDistance); this.orcDataSource = orcDataSource; - this.metadataReader = requireNonNull(metadataReader, "metadataReader is null"); + this.metadataReader = new ExceptionWrappingMetadataReader(orcDataSource.getId(), requireNonNull(delegate, "delegate is null")); this.maxMergeDistance = requireNonNull(maxMergeDistance, "maxMergeDistance is null"); this.maxReadSize = requireNonNull(maxReadSize, "maxReadSize is null"); + this.maxBlockSize = requireNonNull(maxBlockSize, "maxBlockSize is null"); // // Read the file tail: @@ -83,7 +86,7 @@ public OrcReader(OrcDataSource orcDataSource, MetadataReader metadataReader, Dat // figure out the size of the file using the option or filesystem long size = orcDataSource.getSize(); if (size <= 0) { - throw new OrcCorruptionException("Malformed ORC file %s. Invalid file size %s", orcDataSource, size); + throw new OrcCorruptionException(orcDataSource.getId(), "Invalid file size %s", size); } // Read the tail of the file @@ -110,13 +113,13 @@ public OrcReader(OrcDataSource orcDataSource, MetadataReader metadataReader, Dat case UNCOMPRESSED: break; case ZLIB: - decompressor = Optional.of(new OrcZlibDecompressor(bufferSize)); + decompressor = Optional.of(new OrcZlibDecompressor(orcDataSource.getId(), bufferSize)); break; case SNAPPY: - decompressor = Optional.of(new OrcSnappyDecompressor(bufferSize)); + decompressor = Optional.of(new OrcSnappyDecompressor(orcDataSource.getId(), bufferSize)); break; case ZSTD: - decompressor = Optional.of(new OrcZstdDecompressor(bufferSize)); + decompressor = Optional.of(new OrcZstdDecompressor(orcDataSource.getId(), bufferSize)); break; default: throw new UnsupportedOperationException("Unsupported compression type: " + postScript.getCompression()); @@ -148,13 +151,13 @@ public OrcReader(OrcDataSource orcDataSource, MetadataReader metadataReader, Dat // read metadata Slice metadataSlice = completeFooterSlice.slice(0, metadataSize); - try (InputStream metadataInputStream = new OrcInputStream(orcDataSource.toString(), metadataSlice.getInput(), decompressor, new AggregatedMemoryContext())) { + try (InputStream metadataInputStream = new OrcInputStream(orcDataSource.getId(), metadataSlice.getInput(), decompressor, new AggregatedMemoryContext())) { this.metadata = metadataReader.readMetadata(hiveWriterVersion, metadataInputStream); } // read footer Slice footerSlice = completeFooterSlice.slice(metadataSize, footerSize); - try (InputStream footerInputStream = new OrcInputStream(orcDataSource.toString(), footerSlice.getInput(), decompressor, new AggregatedMemoryContext())) { + try (InputStream footerInputStream = new OrcInputStream(orcDataSource.getId(), footerSlice.getInput(), decompressor, new AggregatedMemoryContext())) { this.footer = metadataReader.readFooter(hiveWriterVersion, footerInputStream); } } @@ -212,6 +215,7 @@ public OrcRecordReader createRecordReader( metadataReader, maxMergeDistance, maxReadSize, + maxBlockSize, footer.getUserMetadata(), systemMemoryUsage); } @@ -241,7 +245,7 @@ private static void verifyOrcFooter( { int magicLength = MAGIC.length(); if (postScriptSize < magicLength + 1) { - throw new OrcCorruptionException("Malformed ORC file %s. Invalid postscript length %s", source, postScriptSize); + throw new OrcCorruptionException(source.getId(), "Invalid postscript length %s", postScriptSize); } if (!MAGIC.equals(Slices.wrappedBuffer(buffer, buffer.length - 1 - magicLength, magicLength))) { @@ -251,7 +255,7 @@ private static void verifyOrcFooter( // if it isn't there, this isn't an ORC file if (!MAGIC.equals(Slices.wrappedBuffer(headerMagic))) { - throw new OrcCorruptionException("Malformed ORC file %s. Invalid postscript.", source); + throw new OrcCorruptionException(source.getId(), "Invalid postscript"); } } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java index 9f087372ddbef..ab924c00d95f2 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcRecordReader.java @@ -53,6 +53,7 @@ import static com.facebook.presto.orc.OrcReader.MAX_BATCH_SIZE; import static com.facebook.presto.orc.OrcRecordReader.LinearProbeRangeFinder.createTinyStripesRangeFinder; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.util.Comparator.comparingLong; @@ -64,13 +65,17 @@ public class OrcRecordReader private final OrcDataSource orcDataSource; private final StreamReader[] streamReaders; + private final long[] maxBytesPerCell; + private long maxCombinedBytesPerRow; private final long totalRowCount; private final long splitLength; private final Set presentColumns; + private final long maxBlockBytes; private long currentPosition; private long currentStripePosition; private int currentBatchSize; + private int maxBatchSize = MAX_BATCH_SIZE; private final List stripes; private final StripeReader stripeReader; @@ -107,6 +112,7 @@ public OrcRecordReader( MetadataReader metadataReader, DataSize maxMergeDistance, DataSize maxReadSize, + DataSize maxBlockSize, Map userMetadata, AbstractAggregatedMemoryContext systemMemoryUsage) throws IOException @@ -135,6 +141,8 @@ public OrcRecordReader( } this.presentColumns = presentColumns.build(); + this.maxBlockBytes = requireNonNull(maxBlockSize, "maxBlockSize is null").toBytes(); + // it is possible that old versions of orc use 0 to mean there are no row groups checkArgument(rowsInRowGroup > 0, "rowsInRowGroup must be greater than zero"); @@ -195,6 +203,7 @@ public OrcRecordReader( metadataReader); streamReaders = createStreamReaders(orcDataSource, types, hiveStorageTimeZone, presentColumnsAndTypes.build()); + maxBytesPerCell = new long[streamReaders.length]; } private static boolean splitContainsStripe(long splitOffset, long splitLength, StripeInformation stripe) @@ -279,6 +288,14 @@ public long getSplitLength() return splitLength; } + /** + * Returns the sum of the largest cells in size from each column + */ + public long getMaxCombinedBytesPerRow() + { + return maxCombinedBytesPerRow; + } + @Override public void close() throws IOException @@ -308,7 +325,7 @@ public int nextBatch() } } - currentBatchSize = toIntExact(min(MAX_BATCH_SIZE, currentGroupRowCount - nextRowInGroup)); + currentBatchSize = toIntExact(min(maxBatchSize, currentGroupRowCount - nextRowInGroup)); for (StreamReader column : streamReaders) { if (column != null) { @@ -322,7 +339,16 @@ public int nextBatch() public Block readBlock(Type type, int columnIndex) throws IOException { - return streamReaders[columnIndex].readBlock(type); + Block block = streamReaders[columnIndex].readBlock(type); + if (block.getPositionCount() > 0) { + long bytesPerCell = block.getSizeInBytes() / block.getPositionCount(); + if (maxBytesPerCell[columnIndex] < bytesPerCell) { + maxCombinedBytesPerRow = maxCombinedBytesPerRow - maxBytesPerCell[columnIndex] + bytesPerCell; + maxBytesPerCell[columnIndex] = bytesPerCell; + maxBatchSize = toIntExact(min(maxBatchSize, max(1, maxBlockBytes / maxCombinedBytesPerRow))); + } + } + return block; } public StreamReader getStreamReader(int index) @@ -400,7 +426,8 @@ private void advanceToNextStripe() } } - private static StreamReader[] createStreamReaders(OrcDataSource orcDataSource, + private static StreamReader[] createStreamReaders( + OrcDataSource orcDataSource, List types, DateTimeZone hiveStorageTimeZone, Map includedColumns) diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcSnappyDecompressor.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcSnappyDecompressor.java index aec7d6216c9cd..54d82e543f82c 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcSnappyDecompressor.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcSnappyDecompressor.java @@ -13,19 +13,22 @@ */ package com.facebook.presto.orc; +import io.airlift.compress.MalformedInputException; import io.airlift.compress.snappy.SnappyDecompressor; -import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; +import static java.util.Objects.requireNonNull; class OrcSnappyDecompressor implements OrcDecompressor { + private final OrcDataSourceId orcDataSourceId; private final int maxBufferSize; private final SnappyDecompressor decompressor = new SnappyDecompressor(); - public OrcSnappyDecompressor(int maxBufferSize) + public OrcSnappyDecompressor(OrcDataSourceId orcDataSourceId, int maxBufferSize) { + this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSourceId is null"); this.maxBufferSize = maxBufferSize; } @@ -33,13 +36,20 @@ public OrcSnappyDecompressor(int maxBufferSize) public int decompress(byte[] input, int offset, int length, OutputBuffer output) throws OrcCorruptionException { - int uncompressedLength = SnappyDecompressor.getUncompressedLength(input, offset); - checkArgument(uncompressedLength <= maxBufferSize, "Snappy requires buffer (%s) larger than max size (%s)", uncompressedLength, maxBufferSize); + try { + int uncompressedLength = SnappyDecompressor.getUncompressedLength(input, offset); + if (uncompressedLength > maxBufferSize) { + throw new OrcCorruptionException(orcDataSourceId, "Snappy requires buffer (%s) larger than max size (%s)", uncompressedLength, maxBufferSize); + } - // Snappy decompressor is more if there's at least a long's worth of extra space - // in the output buffer - byte[] buffer = output.initialize(uncompressedLength + SIZE_OF_LONG); - return decompressor.decompress(input, offset, length, buffer, 0, buffer.length); + // Snappy decompressor is more efficient if there's at least a long's worth of extra space + // in the output buffer + byte[] buffer = output.initialize(uncompressedLength + SIZE_OF_LONG); + return decompressor.decompress(input, offset, length, buffer, 0, buffer.length); + } + catch (MalformedInputException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid compressed stream"); + } } @Override diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcZlibDecompressor.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcZlibDecompressor.java index 7be017040a409..e1dc37a0b0d18 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcZlibDecompressor.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcZlibDecompressor.java @@ -16,15 +16,19 @@ import java.util.zip.DataFormatException; import java.util.zip.Inflater; +import static java.util.Objects.requireNonNull; + class OrcZlibDecompressor implements OrcDecompressor { private static final int EXPECTED_COMPRESSION_RATIO = 5; + private final OrcDataSourceId orcDataSourceId; private final int maxBufferSize; - public OrcZlibDecompressor(int maxBufferSize) + public OrcZlibDecompressor(OrcDataSourceId orcDataSourceId, int maxBufferSize) { + this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSourceId is null"); this.maxBufferSize = maxBufferSize; } @@ -50,13 +54,13 @@ public int decompress(byte[] input, int offset, int length, OutputBuffer output) } if (!inflater.finished()) { - throw new OrcCorruptionException("Could not decompress all input (output buffer too small?)"); + throw new OrcCorruptionException(orcDataSourceId, "Could not decompress all input (output buffer too small?)"); } return uncompressedLength; } catch (DataFormatException e) { - throw new OrcCorruptionException(e, "Invalid compressed stream"); + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid compressed stream"); } finally { inflater.end(); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java b/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java index 4e1f3552fbca8..256db7b8f6351 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/OrcZstdDecompressor.java @@ -14,17 +14,21 @@ package com.facebook.presto.orc; import com.facebook.presto.orc.zstd.ZstdDecompressor; +import io.airlift.compress.MalformedInputException; -import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.StrictMath.toIntExact; +import static java.util.Objects.requireNonNull; class OrcZstdDecompressor implements OrcDecompressor { + private final OrcDataSourceId orcDataSourceId; private final int maxBufferSize; private final ZstdDecompressor decompressor = new ZstdDecompressor(); - public OrcZstdDecompressor(int maxBufferSize) + public OrcZstdDecompressor(OrcDataSourceId orcDataSourceId, int maxBufferSize) { + this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSourceId is null"); this.maxBufferSize = maxBufferSize; } @@ -32,11 +36,18 @@ public OrcZstdDecompressor(int maxBufferSize) public int decompress(byte[] input, int offset, int length, OutputBuffer output) throws OrcCorruptionException { - int uncompressedLength = (int) ZstdDecompressor.getDecompressedSize(input, offset, length); - checkArgument(uncompressedLength <= maxBufferSize, "Zstd requires buffer (%s) larger than max size (%s)", uncompressedLength, maxBufferSize); + try { + long uncompressedLength = ZstdDecompressor.getDecompressedSize(input, offset, length); + if (uncompressedLength > maxBufferSize) { + throw new OrcCorruptionException(orcDataSourceId, "Zstd requires buffer (%s) larger than max size (%s)", uncompressedLength, maxBufferSize); + } - byte[] buffer = output.initialize(uncompressedLength); - return decompressor.decompress(input, offset, length, buffer, 0, buffer.length); + byte[] buffer = output.initialize(toIntExact(uncompressedLength)); + return decompressor.decompress(input, offset, length, buffer, 0, buffer.length); + } + catch (MalformedInputException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid compressed stream"); + } } @Override diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/StreamDescriptor.java b/presto-orc/src/main/java/com/facebook/presto/orc/StreamDescriptor.java index e47aa13599cd1..c1842afdc8521 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/StreamDescriptor.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/StreamDescriptor.java @@ -27,16 +27,16 @@ public final class StreamDescriptor private final int streamId; private final OrcTypeKind streamType; private final String fieldName; - private final OrcDataSource fileInput; + private final OrcDataSource orcDataSource; private final List nestedStreams; - public StreamDescriptor(String streamName, int streamId, String fieldName, OrcTypeKind streamType, OrcDataSource fileInput, List nestedStreams) + public StreamDescriptor(String streamName, int streamId, String fieldName, OrcTypeKind streamType, OrcDataSource orcDataSource, List nestedStreams) { this.streamName = requireNonNull(streamName, "streamName is null"); this.streamId = streamId; this.fieldName = requireNonNull(fieldName, "fieldName is null"); this.streamType = requireNonNull(streamType, "type is null"); - this.fileInput = requireNonNull(fileInput, "fileInput is null"); + this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); this.nestedStreams = ImmutableList.copyOf(requireNonNull(nestedStreams, "nestedStreams is null")); } @@ -60,9 +60,14 @@ public String getFieldName() return fieldName; } - public OrcDataSource getFileInput() + public OrcDataSourceId getOrcDataSourceId() { - return fileInput; + return orcDataSource.getId(); + } + + public OrcDataSource getOrcDataSource() + { + return orcDataSource; } public List getNestedStreams() @@ -77,7 +82,7 @@ public String toString() .add("streamName", streamName) .add("streamId", streamId) .add("streamType", streamType) - .add("path", fileInput) + .add("dataSource", orcDataSource.getId()) .toString(); } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/StripeReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/StripeReader.java index b297be9a59312..e636342324fc3 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/StripeReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/StripeReader.java @@ -167,7 +167,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste // If the file does not have a row group dictionary, treat the stripe as a single row group. Otherwise, // we must fail because the length of the row group dictionary is contained in the checkpoint stream. if (hasRowGroupDictionary) { - throw new OrcCorruptionException(e, "ORC file %s has corrupt checkpoints", orcDataSource); + throw new OrcCorruptionException(e, orcDataSource.getId(), "Checkpoints are corrupt"); } } } @@ -220,10 +220,9 @@ public Map readDiskRanges(long stripeOffset, Map streamsData = orcDataSource.readFully(diskRanges); // transform streams to OrcInputStream - String sourceName = orcDataSource.toString(); ImmutableMap.Builder streamsBuilder = ImmutableMap.builder(); for (Entry entry : streamsData.entrySet()) { - streamsBuilder.put(entry.getKey(), new OrcInputStream(sourceName, entry.getValue(), decompressor, systemMemoryUsage)); + streamsBuilder.put(entry.getKey(), new OrcInputStream(orcDataSource.getId(), entry.getValue(), decompressor, systemMemoryUsage)); } return streamsBuilder.build(); } @@ -327,7 +326,7 @@ public StripeFooter readStripeFooter(StripeInformation stripe, AbstractAggregate // read the footer byte[] tailBuffer = new byte[tailLength]; orcDataSource.readFully(offset, tailBuffer); - try (InputStream inputStream = new OrcInputStream(orcDataSource.toString(), Slices.wrappedBuffer(tailBuffer).getInput(), decompressor, systemMemoryUsage)) { + try (InputStream inputStream = new OrcInputStream(orcDataSource.getId(), Slices.wrappedBuffer(tailBuffer).getInput(), decompressor, systemMemoryUsage)) { return metadataReader.readStripeFooter(hiveWriterVersion, types, inputStream); } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/metadata/DwrfMetadataReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/metadata/DwrfMetadataReader.java index fd3e8ec742117..f85716583d6e1 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/metadata/DwrfMetadataReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/metadata/DwrfMetadataReader.java @@ -80,12 +80,17 @@ public Footer readFooter(HiveWriterVersion hiveWriterVersion, InputStream inputS { CodedInputStream input = CodedInputStream.newInstance(inputStream); DwrfProto.Footer footer = DwrfProto.Footer.parseFrom(input); + + // todo enable file stats when DWRF team verifies that the stats are correct + // List fileStats = toColumnStatistics(hiveWriterVersion, footer.getStatisticsList(), false); + List fileStats = ImmutableList.of(); + return new Footer( footer.getNumberOfRows(), footer.getRowIndexStride(), toStripeInformation(footer.getStripesList()), toType(footer.getTypesList()), - toColumnStatistics(hiveWriterVersion, footer.getStatisticsList(), false), + fileStats, toUserMetadata(footer.getMetadataList())); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/metadata/ExceptionWrappingMetadataReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/metadata/ExceptionWrappingMetadataReader.java new file mode 100644 index 0000000000000..686a3ac067343 --- /dev/null +++ b/presto-orc/src/main/java/com/facebook/presto/orc/metadata/ExceptionWrappingMetadataReader.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.orc.metadata; + +import com.facebook.presto.orc.OrcCorruptionException; +import com.facebook.presto.orc.OrcDataSourceId; +import com.facebook.presto.orc.metadata.PostScript.HiveWriterVersion; +import com.facebook.presto.orc.metadata.statistics.HiveBloomFilter; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class ExceptionWrappingMetadataReader + implements MetadataReader +{ + private final OrcDataSourceId orcDataSourceId; + private final MetadataReader delegate; + + public ExceptionWrappingMetadataReader(OrcDataSourceId orcDataSourceId, MetadataReader delegate) + { + this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSourceId is null"); + this.delegate = requireNonNull(delegate, "delegate is null"); + checkArgument(!(delegate instanceof ExceptionWrappingMetadataReader), "ExceptionWrappingMetadataReader can not wrap a ExceptionWrappingMetadataReader"); + } + + @Override + public PostScript readPostScript(byte[] data, int offset, int length) + throws OrcCorruptionException + { + try { + return delegate.readPostScript(data, offset, length); + } + catch (IOException | RuntimeException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid postscript"); + } + } + + @Override + public Metadata readMetadata(HiveWriterVersion hiveWriterVersion, InputStream inputStream) + throws OrcCorruptionException + { + try { + return delegate.readMetadata(hiveWriterVersion, inputStream); + } + catch (IOException | RuntimeException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid file metadata"); + } + } + + @Override + public Footer readFooter(HiveWriterVersion hiveWriterVersion, InputStream inputStream) + throws OrcCorruptionException + { + try { + return delegate.readFooter(hiveWriterVersion, inputStream); + } + catch (IOException | RuntimeException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid file footer"); + } + } + + @Override + public StripeFooter readStripeFooter(HiveWriterVersion hiveWriterVersion, List types, InputStream inputStream) + throws OrcCorruptionException + { + try { + return delegate.readStripeFooter(hiveWriterVersion, types, inputStream); + } + catch (IOException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid stripe footer"); + } + } + + @Override + public List readRowIndexes(HiveWriterVersion hiveWriterVersion, InputStream inputStream) + throws OrcCorruptionException + { + try { + return delegate.readRowIndexes(hiveWriterVersion, inputStream); + } + catch (IOException | RuntimeException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid stripe row index"); + } + } + + @Override + public List readBloomFilterIndexes(InputStream inputStream) + throws OrcCorruptionException + { + try { + return delegate.readBloomFilterIndexes(inputStream); + } + catch (IOException | RuntimeException e) { + throw new OrcCorruptionException(e, orcDataSourceId, "Invalid bloom filter"); + } + } +} diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/BooleanStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/BooleanStreamReader.java index f41b5b437a9ae..15d777c8117c1 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/BooleanStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/BooleanStreamReader.java @@ -85,7 +85,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(readOffset); } @@ -94,7 +94,7 @@ public Block readBlock(Type type) BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.getSetBits(type, nextBatchSize, builder); } @@ -105,7 +105,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.getSetBits(type, nextBatchSize, builder, nullVector); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/ByteStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/ByteStreamReader.java index 7a062894b6451..4bd77beeaf442 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/ByteStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/ByteStreamReader.java @@ -86,7 +86,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(readOffset); } @@ -95,7 +95,7 @@ public Block readBlock(Type type) BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextVector(type, nextBatchSize, builder); } @@ -106,7 +106,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextVector(type, nextBatchSize, builder, nullVector); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/DecimalStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/DecimalStreamReader.java index 5709f6f56a30b..bb6cc05ae148b 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/DecimalStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/DecimalStreamReader.java @@ -97,10 +97,10 @@ public Block readBlock(Type type) BlockBuilder builder = decimalType.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (decimalStream == null) { - throw new OrcCorruptionException("Value is not null but decimal stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but decimal stream is not present"); } if (scaleStream == null) { - throw new OrcCorruptionException("Value is not null but scale stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but scale stream is not present"); } Arrays.fill(nullVector, false); @@ -117,10 +117,10 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (decimalStream == null) { - throw new OrcCorruptionException("Value is not null but decimal stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but decimal stream is not present"); } if (scaleStream == null) { - throw new OrcCorruptionException("Value is not null but scale stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but scale stream is not present"); } scaleStream.nextLongVector(nextBatchSize, scaleVector, nullVector); @@ -165,10 +165,10 @@ private void seekToOffset() } if (readOffset > 0) { if (decimalStream == null) { - throw new OrcCorruptionException("Value is not null but decimal stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but decimal stream is not present"); } if (scaleStream == null) { - throw new OrcCorruptionException("Value is not null but scale stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but scale stream is not present"); } decimalStream.skip(readOffset); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/DoubleStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/DoubleStreamReader.java index 82ddd139ca06f..8ba7eb7e98f02 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/DoubleStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/DoubleStreamReader.java @@ -86,7 +86,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(readOffset); } @@ -95,7 +95,7 @@ public Block readBlock(Type type) BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextVector(type, nextBatchSize, builder); } @@ -106,7 +106,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextVector(type, nextBatchSize, builder, nullVector); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/FloatStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/FloatStreamReader.java index ae8b6cc7dac7e..1b7b07efd1edc 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/FloatStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/FloatStreamReader.java @@ -86,7 +86,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(readOffset); } @@ -95,7 +95,7 @@ public Block readBlock(Type type) BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextVector(type, nextBatchSize, builder); } @@ -106,7 +106,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextVector(type, nextBatchSize, builder, nullVector); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/ListStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/ListStreamReader.java index 6764898201bb9..8ee1db16a862b 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/ListStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/ListStreamReader.java @@ -91,7 +91,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } long elementSkipSize = lengthStream.sum(readOffset); elementStreamReader.prepareNextRead(toIntExact(elementSkipSize)); @@ -106,7 +106,7 @@ public Block readBlock(Type type) boolean[] nullVector = new boolean[nextBatchSize]; if (presentStream == null) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.nextIntVector(nextBatchSize, lengthVector); } @@ -114,7 +114,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.nextIntVector(nextBatchSize, lengthVector, nullVector); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDictionaryStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDictionaryStreamReader.java index 29e22b90e80d4..a207ca41d55e6 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDictionaryStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDictionaryStreamReader.java @@ -108,7 +108,7 @@ public Block readBlock(Type type) if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(readOffset); } @@ -122,7 +122,7 @@ public Block readBlock(Type type) } if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } Arrays.fill(nullVector, false); dataStream.nextLongVector(nextBatchSize, dataVector); @@ -131,7 +131,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextLongVector(nextBatchSize, dataVector, nullVector); } @@ -177,7 +177,7 @@ private void openRowGroup() LongInputStream dictionaryStream = dictionaryDataStreamSource.openStream(); if (dictionaryStream == null) { - throw new OrcCorruptionException("Dictionary is not empty but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Dictionary is not empty but data stream is not present"); } dictionaryStream.nextLongVector(dictionarySize, dictionary); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDirectStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDirectStreamReader.java index 1e484f28574c1..a00db0b000145 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDirectStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/LongDirectStreamReader.java @@ -86,7 +86,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(readOffset); } @@ -95,7 +95,7 @@ public Block readBlock(Type type) BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextLongVector(type, nextBatchSize, builder); } @@ -106,7 +106,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextLongVector(type, nextBatchSize, builder, nullVector); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java index 34994989acd17..3514ede45dbdb 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/MapStreamReader.java @@ -20,10 +20,9 @@ import com.facebook.presto.orc.stream.InputStreamSource; import com.facebook.presto.orc.stream.InputStreamSources; import com.facebook.presto.orc.stream.LongInputStream; -import com.facebook.presto.spi.block.ArrayBlock; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlock; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrayList; import org.joda.time.DateTimeZone; @@ -95,7 +94,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } long entrySkipSize = lengthStream.sum(readOffset); keyStreamReader.prepareNextRead(toIntExact(entrySkipSize)); @@ -111,7 +110,7 @@ public Block readBlock(Type type) boolean[] nullVector = new boolean[nextBatchSize]; if (presentStream == null) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.nextIntVector(nextBatchSize, lengthVector); } @@ -119,14 +118,15 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.nextIntVector(nextBatchSize, lengthVector, nullVector); } } - Type keyType = type.getTypeParameters().get(0); - Type valueType = type.getTypeParameters().get(1); + MapType mapType = (MapType) type; + Type keyType = mapType.getKeyType(); + Type valueType = mapType.getValueType(); int entryCount = 0; for (int length : lengthVector) { @@ -146,26 +146,25 @@ public Block readBlock(Type type) values = valueType.createBlockBuilder(new BlockBuilderStatus(), 1).build(); } - InterleavedBlock keyValueBlock = createKeyValueBlock(nextBatchSize, keys, values, lengthVector); + Block[] keyValueBlock = createKeyValueBlock(nextBatchSize, keys, values, lengthVector); // convert lengths into offsets into the keyValueBlock (e.g., two positions per entry) int[] offsets = new int[nextBatchSize + 1]; for (int i = 1; i < offsets.length; i++) { - int length = lengthVector[i - 1] * 2; + int length = lengthVector[i - 1]; offsets[i] = offsets[i - 1] + length; } - ArrayBlock arrayBlock = new ArrayBlock(nextBatchSize, nullVector, offsets, keyValueBlock); readOffset = 0; nextBatchSize = 0; - return arrayBlock; + return mapType.createBlockFromKeyValue(nullVector, offsets, keyValueBlock[0], keyValueBlock[1]); } - private static InterleavedBlock createKeyValueBlock(int positionCount, Block keys, Block values, int[] lengths) + private static Block[] createKeyValueBlock(int positionCount, Block keys, Block values, int[] lengths) { if (!hasNull(keys)) { - return new InterleavedBlock(new Block[] {keys, values}); + return new Block[] {keys, values}; } // @@ -191,7 +190,7 @@ private static InterleavedBlock createKeyValueBlock(int positionCount, Block key Block newKeys = keys.copyPositions(nonNullPositions); Block newValues = values.copyPositions(nonNullPositions); - return new InterleavedBlock(new Block[] {newKeys, newValues}); + return new Block[] {newKeys, newValues}; } private static boolean hasNull(Block keys) diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java index 8c86b00315c06..09b3709252d4e 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDictionaryStreamReader.java @@ -128,7 +128,7 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } if (inDictionaryStream != null) { inDictionaryStream.skip(readOffset); @@ -144,7 +144,7 @@ public Block readBlock(Type type) int[] dataVector = new int[nextBatchSize]; if (presentStream == null) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } Arrays.fill(isNullVector, false); dataStream.nextIntVector(nextBatchSize, dataVector); @@ -153,7 +153,7 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, isNullVector); if (nullValues != nextBatchSize) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.nextIntVector(nextBatchSize, dataVector, isNullVector); } @@ -184,7 +184,6 @@ else if (inDictionary[i]) { } } - // copy ids into a private array for this block since data vector is reused Block block = new DictionaryBlock(nextBatchSize, dictionaryBlock, dataVector); readOffset = 0; @@ -215,7 +214,7 @@ private void openRowGroup(Type type) // read the lengths LongInputStream lengthStream = stripeDictionaryLengthStreamSource.openStream(); if (lengthStream == null) { - throw new OrcCorruptionException("Dictionary is not empty but dictionary length stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Dictionary is not empty but dictionary length stream is not present"); } lengthStream.nextIntVector(stripeDictionarySize, dictionaryLength); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java index b65363e8a74bb..5db980cb30435 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/SliceDirectStreamReader.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.type.Type; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.airlift.units.DataSize; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -43,12 +44,15 @@ import static com.facebook.presto.spi.type.Varchars.isVarcharType; import static com.facebook.presto.spi.type.Varchars.truncateToLength; import static com.google.common.base.MoreObjects.toStringHelper; +import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class SliceDirectStreamReader implements StreamReader { private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + private static final int ONE_GIGABYTE = toIntExact(new DataSize(1, GIGABYTE).toBytes()); private final StreamDescriptor streamDescriptor; @@ -102,12 +106,12 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but length stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but length stream is not present"); } long dataSkipSize = lengthStream.sum(readOffset); if (dataSkipSize > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } dataStream.skip(dataSkipSize); } @@ -122,7 +126,7 @@ public Block readBlock(Type type) } if (presentStream == null) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but length stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but length stream is not present"); } Arrays.fill(isNullVector, false); lengthStream.nextIntVector(nextBatchSize, lengthVector); @@ -131,25 +135,28 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, isNullVector); if (nullValues != nextBatchSize) { if (lengthStream == null) { - throw new OrcCorruptionException("Value is not null but length stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but length stream is not present"); } lengthStream.nextIntVector(nextBatchSize, lengthVector, isNullVector); } } - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < nextBatchSize; i++) { if (!isNullVector[i]) { totalLength += lengthVector[i]; } } + if (totalLength > ONE_GIGABYTE) { + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Column values too large to process in Presto. %s column values larger than 1GB", nextBatchSize); + } byte[] data = EMPTY_BYTE_ARRAY; if (totalLength > 0) { if (dataStream == null) { - throw new OrcCorruptionException("Value is not null but data stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); } - data = dataStream.next(totalLength); + data = dataStream.next(toIntExact(totalLength)); } Slice[] sliceVector = new Slice[nextBatchSize]; diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/reader/TimestampStreamReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/reader/TimestampStreamReader.java index 22aa5068b67ad..9c7e00c5a182f 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/reader/TimestampStreamReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/reader/TimestampStreamReader.java @@ -101,10 +101,10 @@ public Block readBlock(Type type) } if (readOffset > 0) { if (secondsStream == null) { - throw new OrcCorruptionException("Value is not null but seconds stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but seconds stream is not present"); } if (nanosStream == null) { - throw new OrcCorruptionException("Value is not null but nanos stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but nanos stream is not present"); } secondsStream.skip(readOffset); @@ -122,10 +122,10 @@ public Block readBlock(Type type) BlockBuilder builder = type.createBlockBuilder(new BlockBuilderStatus(), nextBatchSize); if (presentStream == null) { if (secondsStream == null) { - throw new OrcCorruptionException("Value is not null but seconds stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but seconds stream is not present"); } if (nanosStream == null) { - throw new OrcCorruptionException("Value is not null but nanos stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but nanos stream is not present"); } secondsStream.nextLongVector(nextBatchSize, secondsVector); @@ -143,10 +143,10 @@ public Block readBlock(Type type) int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (secondsStream == null) { - throw new OrcCorruptionException("Value is not null but seconds stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but seconds stream is not present"); } if (nanosStream == null) { - throw new OrcCorruptionException("Value is not null but nanos stream is not present"); + throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but nanos stream is not present"); } secondsStream.nextLongVector(nextBatchSize, secondsVector, nullVector); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteArrayInputStream.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteArrayInputStream.java index ef2b90e5955d5..17c089f110506 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteArrayInputStream.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteArrayInputStream.java @@ -17,8 +17,6 @@ import java.io.IOException; -import static com.facebook.presto.orc.stream.OrcStreamUtils.readFully; -import static com.facebook.presto.orc.stream.OrcStreamUtils.skipFully; import static java.util.Objects.requireNonNull; public class ByteArrayInputStream @@ -35,14 +33,14 @@ public byte[] next(int length) throws IOException { byte[] data = new byte[length]; - readFully(inputStream, data, 0, length); + inputStream.readFully(data, 0, length); return data; } public void next(int length, byte[] data) throws IOException { - readFully(inputStream, data, 0, length); + inputStream.readFully(data, 0, length); } @Override @@ -62,6 +60,6 @@ public void seekToCheckpoint(ByteArrayStreamCheckpoint checkpoint) public void skip(long skipSize) throws IOException { - skipFully(inputStream, skipSize); + inputStream.skipFully(skipSize); } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteInputStream.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteInputStream.java index e2f02a381df07..9b3987d5d37b1 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteInputStream.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/ByteInputStream.java @@ -21,12 +21,11 @@ import java.io.IOException; import java.util.Arrays; -import static com.facebook.presto.orc.stream.OrcStreamUtils.MIN_REPEAT_SIZE; -import static com.facebook.presto.orc.stream.OrcStreamUtils.readFully; - public class ByteInputStream implements ValueInputStream { + private static final int MIN_REPEAT_SIZE = 3; + private final OrcInputStream input; private final byte[] buffer = new byte[MIN_REPEAT_SIZE + 127]; private int length; @@ -47,7 +46,7 @@ private void readNextBlock() int control = input.read(); if (control == -1) { - throw new OrcCorruptionException("Read past end of buffer RLE byte from %s", input); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Read past end of buffer RLE byte"); } offset = 0; @@ -59,7 +58,7 @@ private void readNextBlock() // read the repeated value int value = input.read(); if (value == -1) { - throw new OrcCorruptionException("Reading RLE byte got EOF"); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Reading RLE byte got EOF"); } // fill buffer with the value @@ -70,7 +69,7 @@ private void readNextBlock() length = 0x100 - control; // read the literals into the buffer - readFully(input, buffer, 0, length); + input.readFully(buffer, 0, length); } } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/DecimalInputStream.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/DecimalInputStream.java index 637084cc92ab9..121b1589cde96 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/DecimalInputStream.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/DecimalInputStream.java @@ -59,7 +59,7 @@ public void nextLongDecimal(Slice result) do { b = input.read(); if (offset == 126 && ((b & 0x80) > 0 || (b & 0x7f) > 3)) { - throw new OrcCorruptionException("Decimal exceeds 128 bits"); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Decimal exceeds 128 bits"); } if (offset < 63) { @@ -135,11 +135,11 @@ public long nextLong() do { b = input.read(); if (b == -1) { - throw new OrcCorruptionException("Reading BigInteger past EOF from " + input); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Reading BigInteger past EOF"); } long work = 0x7f & b; if (offset >= 63 && (offset != 63 || work > 1)) { - throw new OrcCorruptionException("Decimal does not fit long (invalid table schema?)"); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Decimal does not fit long (invalid table schema?)"); } result |= work << offset; offset += 7; @@ -192,7 +192,7 @@ public void skip(long items) do { b = input.read(); if (b == -1) { - throw new OrcCorruptionException("Reading BigInteger past EOF from " + input); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Reading BigInteger past EOF"); } } while (b >= 0x80); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/DoubleInputStream.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/DoubleInputStream.java index c598f24860b13..79bb2ac2f19fd 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/DoubleInputStream.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/DoubleInputStream.java @@ -21,8 +21,6 @@ import java.io.IOException; -import static com.facebook.presto.orc.stream.OrcStreamUtils.readFully; -import static com.facebook.presto.orc.stream.OrcStreamUtils.skipFully; import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; public class DoubleInputStream @@ -55,13 +53,13 @@ public void skip(long items) throws IOException { long length = items * SIZE_OF_DOUBLE; - skipFully(input, length); + input.skipFully(length); } public double next() throws IOException { - readFully(input, buffer, 0, SIZE_OF_DOUBLE); + input.readFully(buffer, 0, SIZE_OF_DOUBLE); return slice.getDouble(0); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/FloatInputStream.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/FloatInputStream.java index 87b41ba6dd8d5..f62af8d4d5383 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/FloatInputStream.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/FloatInputStream.java @@ -21,8 +21,6 @@ import java.io.IOException; -import static com.facebook.presto.orc.stream.OrcStreamUtils.readFully; -import static com.facebook.presto.orc.stream.OrcStreamUtils.skipFully; import static io.airlift.slice.SizeOf.SIZE_OF_FLOAT; import static java.lang.Float.floatToRawIntBits; @@ -56,13 +54,13 @@ public void skip(long items) throws IOException { long length = items * SIZE_OF_FLOAT; - skipFully(input, length); + input.skipFully(length); } public float next() throws IOException { - readFully(input, buffer, 0, SIZE_OF_FLOAT); + input.readFully(buffer, 0, SIZE_OF_FLOAT); return slice.getFloat(0); } diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongDecode.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongDecode.java index 7c6133b14bb65..e6e99771ac5a7 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongDecode.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongDecode.java @@ -17,7 +17,6 @@ import com.facebook.presto.orc.metadata.OrcType.OrcTypeKind; import java.io.IOException; -import java.io.InputStream; import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.INT; import static com.facebook.presto.orc.metadata.OrcType.OrcTypeKind.LONG; @@ -119,14 +118,14 @@ else if (width > 48 && width <= 56) { } } - public static long readSignedVInt(InputStream inputStream) + public static long readSignedVInt(OrcInputStream inputStream) throws IOException { long result = readUnsignedVInt(inputStream); return (result >>> 1) ^ -(result & 1); } - public static long readUnsignedVInt(InputStream inputStream) + public static long readUnsignedVInt(OrcInputStream inputStream) throws IOException { long result = 0; @@ -135,7 +134,7 @@ public static long readUnsignedVInt(InputStream inputStream) do { b = inputStream.read(); if (b == -1) { - throw new OrcCorruptionException("EOF while reading unsigned vint"); + throw new OrcCorruptionException(inputStream.getOrcDataSourceId(), "EOF while reading unsigned vint"); } result |= (b & 0b0111_1111) << offset; offset += 7; @@ -143,7 +142,7 @@ public static long readUnsignedVInt(InputStream inputStream) return result; } - public static long readVInt(boolean signed, InputStream inputStream) + public static long readVInt(boolean signed, OrcInputStream inputStream) throws IOException { if (signed) { @@ -159,7 +158,7 @@ public static long zigzagDecode(long value) return (value >>> 1) ^ -(value & 1); } - public static long readDwrfLong(InputStream input, OrcTypeKind type, boolean signed, boolean usesVInt) + public static long readDwrfLong(OrcInputStream input, OrcTypeKind type, boolean signed, boolean usesVInt) throws IOException { if (usesVInt) { diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV1.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV1.java index 79b40ffe0c592..d94a430df723f 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV1.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV1.java @@ -21,12 +21,12 @@ import java.io.IOException; -import static com.facebook.presto.orc.stream.OrcStreamUtils.MIN_REPEAT_SIZE; import static java.lang.Math.toIntExact; public class LongInputStreamV1 implements LongInputStream { + private static final int MIN_REPEAT_SIZE = 3; private static final int MAX_LITERAL_SIZE = 128; private final OrcInputStream input; @@ -53,7 +53,7 @@ private void readValues() int control = input.read(); if (control == -1) { - throw new OrcCorruptionException("Read past end of RLE integer from %s", input); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Read past end of RLE integer"); } if (control < 0x80) { @@ -62,7 +62,7 @@ private void readValues() repeat = true; delta = input.read(); if (delta == -1) { - throw new OrcCorruptionException("End of stream in RLE Integer from %s", input); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "End of stream in RLE Integer"); } // convert from 0 to 255 to -128 to 127 by converting to a signed byte diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV2.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV2.java index 2e3c8d4ff30ae..9d64687932ade 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV2.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/LongInputStreamV2.java @@ -22,7 +22,6 @@ import java.io.IOException; import java.io.InputStream; -import static com.facebook.presto.orc.stream.OrcStreamUtils.MIN_REPEAT_SIZE; import static java.lang.Math.toIntExact; /** @@ -32,6 +31,7 @@ public class LongInputStreamV2 implements LongInputStream { + private static final int MIN_REPEAT_SIZE = 3; private static final int MAX_LITERAL_SIZE = 512; private enum EncodingType @@ -65,7 +65,7 @@ private void readValues() // read the first 2 bits and determine the encoding type int firstByte = input.read(); if (firstByte < 0) { - throw new OrcCorruptionException("Read past end of RLE integer from %s", input); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Read past end of RLE integer"); } int enc = (firstByte >>> 6) & 0x03; @@ -188,7 +188,7 @@ private void readPatchedBaseValues(int firstByte) long[] unpackedPatch = new long[patchListLength]; if ((patchWidth + patchGapWidth) > 64 && !skipCorrupt) { - throw new OrcCorruptionException("ORC file is corrupt"); + throw new OrcCorruptionException(input.getOrcDataSourceId(), "Invalid RLEv2 encoded stream"); } int bitSize = LongDecode.getClosestFixedBits(patchWidth + patchGapWidth); diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcInputStream.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcInputStream.java index 52ce66b8637cf..61f873f9a951b 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcInputStream.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcInputStream.java @@ -14,6 +14,7 @@ package com.facebook.presto.orc.stream; import com.facebook.presto.orc.OrcCorruptionException; +import com.facebook.presto.orc.OrcDataSourceId; import com.facebook.presto.orc.OrcDecompressor; import com.facebook.presto.orc.memory.AbstractAggregatedMemoryContext; import com.facebook.presto.orc.memory.LocalMemoryContext; @@ -38,7 +39,7 @@ public final class OrcInputStream extends InputStream { - private final String source; + private final OrcDataSourceId orcDataSourceId; private final FixedLengthSliceInput compressedSliceInput; private final Optional decompressor; @@ -55,9 +56,9 @@ public final class OrcInputStream // * Memory pointed to by `current` is always part of `buffer`. It shouldn't be counted again. private final LocalMemoryContext fixedMemoryUsage; - public OrcInputStream(String source, FixedLengthSliceInput sliceInput, Optional decompressor, AbstractAggregatedMemoryContext systemMemoryContext) + public OrcInputStream(OrcDataSourceId orcDataSourceId, FixedLengthSliceInput sliceInput, Optional decompressor, AbstractAggregatedMemoryContext systemMemoryContext) { - this.source = requireNonNull(source, "source is null"); + this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSource is null"); requireNonNull(sliceInput, "sliceInput is null"); @@ -140,6 +141,35 @@ public int read(byte[] b, int off, int length) return current.read(b, off, length); } + public void skipFully(long length) + throws IOException + { + while (length > 0) { + long result = skip(length); + if (result < 0) { + throw new OrcCorruptionException(orcDataSourceId, "Unexpected end of stream"); + } + length -= result; + } + } + + public void readFully(byte[] buffer, int offset, int length) + throws IOException + { + while (offset < length) { + int result = read(buffer, offset, length - offset); + if (result < 0) { + throw new OrcCorruptionException(orcDataSourceId, "Unexpected end of stream"); + } + offset += result; + } + } + + public OrcDataSourceId getOrcDataSourceId() + { + return orcDataSourceId; + } + public long getCheckpoint() { // if the decompressed buffer is empty, return a checkpoint starting at the next block @@ -158,7 +188,7 @@ public boolean seekToCheckpoint(long checkpoint) boolean discardedBuffer; if (compressedBlockOffset != currentCompressedBlockOffset) { if (!decompressor.isPresent()) { - throw new OrcCorruptionException("Reset stream has a compressed block offset but stream is not compressed"); + throw new OrcCorruptionException(orcDataSourceId, "Reset stream has a compressed block offset but stream is not compressed"); } compressedSliceInput.setPosition(compressedBlockOffset); current = EMPTY_SLICE.getInput(); @@ -216,7 +246,7 @@ private void advance() boolean isUncompressed = (b0 & 0x01) == 1; int chunkLength = (b2 << 15) | (b1 << 7) | (b0 >>> 1); if (chunkLength < 0 || chunkLength > compressedSliceInput.remaining()) { - throw new OrcCorruptionException(String.format("The chunkLength (%s) must not be negative or greater than remaining size (%s)", chunkLength, compressedSliceInput.remaining())); + throw new OrcCorruptionException(orcDataSourceId, "The chunkLength (%s) must not be negative or greater than remaining size (%s)", chunkLength, compressedSliceInput.remaining()); } Slice chunk = compressedSliceInput.readSlice(chunkLength); @@ -256,7 +286,7 @@ public byte[] grow(int size) public String toString() { return toStringHelper(this) - .add("source", source) + .add("source", orcDataSourceId) .add("compressedOffset", compressedSliceInput.position()) .add("uncompressedOffset", current == null ? null : current.position()) .add("decompressor", decompressor.map(Object::toString).orElse("none")) diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcStreamUtils.java b/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcStreamUtils.java deleted file mode 100644 index 704131523c86b..0000000000000 --- a/presto-orc/src/main/java/com/facebook/presto/orc/stream/OrcStreamUtils.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.orc.stream; - -import com.facebook.presto.orc.OrcCorruptionException; - -import java.io.IOException; -import java.io.InputStream; - -final class OrcStreamUtils -{ - public static final int MIN_REPEAT_SIZE = 3; - - private OrcStreamUtils() - { - } - - public static void skipFully(InputStream input, long length) - throws IOException - { - while (length > 0) { - long result = input.skip(length); - if (result < 0) { - throw new OrcCorruptionException("Unexpected end of stream"); - } - length -= result; - } - } - - public static void readFully(InputStream input, byte[] buffer, int offset, int length) - throws IOException - { - while (offset < length) { - int result = input.read(buffer, offset, length - offset); - if (result < 0) { - throw new OrcCorruptionException("Unexpected end of stream"); - } - offset += result; - } - } -} diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/zstd/FseTableReader.java b/presto-orc/src/main/java/com/facebook/presto/orc/zstd/FseTableReader.java index 9a98b03d793e0..11977c4bdf603 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/zstd/FseTableReader.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/zstd/FseTableReader.java @@ -48,7 +48,7 @@ public int readFseTable(FiniteStateEntropy.Table table, Object inputBase, long i int remaining = (1 << tableLog) + 1; threshold = 1 << tableLog; - while ((remaining > 1) && symbolNumber < maxSymbol) { + while (remaining > 1 && symbolNumber <= maxSymbol) { if (previousIsZero) { int n0 = symbolNumber; while ((bitStream & 0xFFFF) == 0xFFFF) { diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java index e35300987ac08..739a1482ad88d 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/BenchmarkOrcDecimalReader.java @@ -117,7 +117,7 @@ private OrcRecordReader createRecordReader() { OrcDataSource dataSource = new FileOrcDataSource(dataPath, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE)); MetadataReader metadataReader = new OrcMetadataReader(); - OrcReader orcReader = new OrcReader(dataSource, metadataReader, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE)); + OrcReader orcReader = new OrcReader(dataSource, metadataReader, new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, MEGABYTE)); return orcReader.createRecordReader( ImmutableMap.of(0, DECIMAL_TYPE), OrcPredicate.TRUE, diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java b/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java index f41f104b8e0ee..e5e4fbf387ede 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/OrcTester.java @@ -14,6 +14,8 @@ package com.facebook.presto.orc; import com.facebook.hive.orc.OrcConf; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.orc.memory.AggregatedMemoryContext; import com.facebook.presto.orc.metadata.DwrfMetadataReader; import com.facebook.presto.orc.metadata.MetadataReader; @@ -32,6 +34,7 @@ import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.spi.type.VarbinaryType; import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.type.TypeRegistry; import com.google.common.base.Throwables; import com.google.common.collect.AbstractIterator; @@ -131,6 +134,10 @@ public class OrcTester public static final DateTimeZone HIVE_STORAGE_TIME_ZONE = DateTimeZone.forID("Asia/Katmandu"); private static final TypeManager TYPE_MANAGER = new TypeRegistry(); + static { + // associate TYPE_MANAGER with a function registry + new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig()); + } public enum Format { @@ -219,7 +226,7 @@ public static OrcTester quickOrcTester() orcTester.listTestsEnabled = true; orcTester.nullTestsEnabled = true; orcTester.skipBatchTestsEnabled = true; - orcTester.formats = ImmutableSet.of(ORC_12); + orcTester.formats = ImmutableSet.of(ORC_12, DWRF); orcTester.compressions = ImmutableSet.of(ZLIB); return orcTester; } @@ -411,14 +418,14 @@ public void assertRoundTrip(Type type, List readValues) try (TempFile tempFile = new TempFile()) { writeOrcColumn(tempFile.getFile(), format, compression, type, readValues.iterator()); - assertFileContents(type, tempFile, readValues, false, false, metadataReader); + assertFileContents(type, tempFile, readValues, false, false, metadataReader, format); if (skipBatchTestsEnabled) { - assertFileContents(type, tempFile, readValues, true, false, metadataReader); + assertFileContents(type, tempFile, readValues, true, false, metadataReader, format); } if (skipStripeTestsEnabled) { - assertFileContents(type, tempFile, readValues, false, true, metadataReader); + assertFileContents(type, tempFile, readValues, false, true, metadataReader, format); } } } @@ -431,10 +438,11 @@ private static void assertFileContents( List expectedValues, boolean skipFirstBatch, boolean skipStripe, - MetadataReader metadataReader) + MetadataReader metadataReader, + Format format) throws IOException { - try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, metadataReader, createOrcPredicate(type, expectedValues), type)) { + try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, metadataReader, createOrcPredicate(type, expectedValues, format == DWRF), type)) { assertEquals(recordReader.getReaderPosition(), 0); assertEquals(recordReader.getFilePosition(), 0); @@ -551,7 +559,7 @@ static OrcRecordReader createCustomOrcRecordReader(TempFile tempFile, MetadataRe throws IOException { OrcDataSource orcDataSource = new FileOrcDataSource(tempFile.getFile(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); - OrcReader orcReader = new OrcReader(orcDataSource, metadataReader, new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(orcDataSource, metadataReader, new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); assertEquals(orcReader.getColumnNames(), ImmutableList.of("test")); assertEquals(orcReader.getFooter().getRowsInRowGroup(), 10_000); diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java index a917b4b2b4694..d9cf0a5d1b6f4 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestCachingOrcDataSource.java @@ -193,7 +193,7 @@ public void testIntegration() public void doIntegration(TestingOrcDataSource orcDataSource, DataSize maxMergeDistance, DataSize maxReadSize) throws IOException { - OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), maxMergeDistance, maxReadSize); + OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), maxMergeDistance, maxReadSize, new DataSize(1, Unit.MEGABYTE)); // 1 for reading file footer assertEquals(orcDataSource.getReadCount(), 1); List stripes = orcReader.getFooter().getStripes(); @@ -254,6 +254,12 @@ private static class FakeOrcDataSource { public static final FakeOrcDataSource INSTANCE = new FakeOrcDataSource(); + @Override + public OrcDataSourceId getId() + { + return new OrcDataSourceId("fake"); + } + @Override public long getReadBytes() { diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java index 839b618d7c0a5..46eefb7a3a96c 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestOrcReaderPositions.java @@ -183,7 +183,7 @@ public void testReadUserMetadata() createFileWithOnlyUserMetadata(tempFile.getFile(), metadata); OrcDataSource orcDataSource = new FileOrcDataSource(tempFile.getFile(), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE)); - OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(orcDataSource, new OrcMetadataReader(), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE), new DataSize(1, DataSize.Unit.MEGABYTE)); Footer footer = orcReader.getFooter(); Map readMetadata = Maps.transformValues(footer.getUserMetadata(), Slice::toStringAscii); assertEquals(readMetadata, metadata); diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcDataSource.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcDataSource.java index 6862723fc78f1..ed01f470f946d 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcDataSource.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcDataSource.java @@ -35,6 +35,12 @@ public TestingOrcDataSource(OrcDataSource delegate) this.delegate = requireNonNull(delegate, "delegate is null"); } + @Override + public OrcDataSourceId getId() + { + return delegate.getId(); + } + public int getReadCount() { return readCount; diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcPredicate.java b/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcPredicate.java index 8d45385505b70..28e75d7169e2e 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcPredicate.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/TestingOrcPredicate.java @@ -62,53 +62,57 @@ private TestingOrcPredicate() { } - public static OrcPredicate createOrcPredicate(Type type, Iterable values) + public static OrcPredicate createOrcPredicate(Type type, Iterable values, boolean noFileStats) { List expectedValues = newArrayList(values); if (BOOLEAN.equals(type)) { - return new BooleanOrcPredicate(expectedValues); + return new BooleanOrcPredicate(expectedValues, noFileStats); } if (TINYINT.equals(type) || SMALLINT.equals(type) || INTEGER.equals(type) || BIGINT.equals(type)) { return new LongOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : ((Number) value).longValue()) - .collect(toList())); + .collect(toList()), + noFileStats); } if (TIMESTAMP.equals(type)) { return new LongOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : ((SqlTimestamp) value).getMillisUtc()) - .collect(toList())); + .collect(toList()), + noFileStats); } if (DATE.equals(type)) { return new DateOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : (long) ((SqlDate) value).getDays()) - .collect(toList())); + .collect(toList()), + noFileStats); } if (REAL.equals(type) || DOUBLE.equals(type)) { return new DoubleOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : ((Number) value).doubleValue()) - .collect(toList())); + .collect(toList()), + noFileStats); } if (type instanceof VarbinaryType) { // binary does not have stats - return new BasicOrcPredicate<>(expectedValues, Object.class); + return new BasicOrcPredicate<>(expectedValues, Object.class, noFileStats); } if (type instanceof VarcharType) { - return new StringOrcPredicate(expectedValues); + return new StringOrcPredicate(expectedValues, noFileStats); } if (type instanceof CharType) { - return new CharOrcPredicate(expectedValues); + return new CharOrcPredicate(expectedValues, noFileStats); } if (type instanceof DecimalType) { - return new DecimalOrcPredicate(expectedValues); + return new DecimalOrcPredicate(expectedValues, noFileStats); } String baseType = type.getTypeSignature().getBase(); if (ARRAY.equals(baseType) || MAP.equals(baseType) || ROW.equals(baseType)) { - return new BasicOrcPredicate<>(expectedValues, Object.class); + return new BasicOrcPredicate<>(expectedValues, Object.class, noFileStats); } throw new IllegalArgumentException("Unsupported type " + type); } @@ -117,21 +121,29 @@ public static class BasicOrcPredicate implements OrcPredicate { private final List expectedValues; + private final boolean noFileStats; - public BasicOrcPredicate(Iterable expectedValues, Class type) + public BasicOrcPredicate(Iterable expectedValues, Class type, boolean noFileStats) { List values = new ArrayList<>(); for (Object expectedValue : expectedValues) { values.add(type.cast(expectedValue)); } this.expectedValues = Collections.unmodifiableList(values); + this.noFileStats = noFileStats; } @Override public boolean matches(long numberOfRows, Map statisticsByColumnIndex) { ColumnStatistics columnStatistics = statisticsByColumnIndex.get(0); - assertTrue(columnStatistics.hasNumberOfValues()); + + // todo enable file stats when DWRF team verifies that the stats are correct + // assertTrue(columnStatistics.hasNumberOfValues()); + if (noFileStats && numberOfRows == expectedValues.size()) { + assertNull(columnStatistics); + return true; + } if (numberOfRows == expectedValues.size()) { // whole file @@ -181,9 +193,9 @@ protected boolean chunkMatchesStats(List chunk, ColumnStatistics columnStatis public static class BooleanOrcPredicate extends BasicOrcPredicate { - public BooleanOrcPredicate(Iterable expectedValues) + public BooleanOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, Boolean.class); + super(expectedValues, Boolean.class, noFileStats); } @Override @@ -212,9 +224,9 @@ protected boolean chunkMatchesStats(List chunk, ColumnStatistics column public static class DoubleOrcPredicate extends BasicOrcPredicate { - public DoubleOrcPredicate(Iterable expectedValues) + public DoubleOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, Double.class); + super(expectedValues, Double.class, noFileStats); } @Override @@ -249,18 +261,18 @@ protected boolean chunkMatchesStats(List chunk, ColumnStatistics columnS private static class DecimalOrcPredicate extends BasicOrcPredicate { - public DecimalOrcPredicate(Iterable expectedValues) + public DecimalOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, SqlDecimal.class); + super(expectedValues, SqlDecimal.class, noFileStats); } } public static class LongOrcPredicate extends BasicOrcPredicate { - public LongOrcPredicate(Iterable expectedValues) + public LongOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, Long.class); + super(expectedValues, Long.class, noFileStats); } @Override @@ -296,9 +308,9 @@ protected boolean chunkMatchesStats(List chunk, ColumnStatistics columnSta public static class StringOrcPredicate extends BasicOrcPredicate { - public StringOrcPredicate(Iterable expectedValues) + public StringOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, String.class); + super(expectedValues, String.class, noFileStats); } @Override @@ -341,9 +353,9 @@ protected boolean chunkMatchesStats(List chunk, ColumnStatistics columnS public static class CharOrcPredicate extends BasicOrcPredicate { - public CharOrcPredicate(Iterable expectedValues) + public CharOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, String.class); + super(expectedValues, String.class, noFileStats); } @Override @@ -386,9 +398,9 @@ protected boolean chunkMatchesStats(List chunk, ColumnStatistics columnS public static class DateOrcPredicate extends BasicOrcPredicate { - public DateOrcPredicate(Iterable expectedValues) + public DateOrcPredicate(Iterable expectedValues, boolean noFileStats) { - super(expectedValues, Long.class); + super(expectedValues, Long.class, noFileStats); } @Override diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/stream/TestDecimalStream.java b/presto-orc/src/test/java/com/facebook/presto/orc/stream/TestDecimalStream.java index a52925377f19c..378e129f9c788 100644 --- a/presto-orc/src/test/java/com/facebook/presto/orc/stream/TestDecimalStream.java +++ b/presto-orc/src/test/java/com/facebook/presto/orc/stream/TestDecimalStream.java @@ -14,6 +14,7 @@ package com.facebook.presto.orc.stream; import com.facebook.presto.orc.OrcCorruptionException; +import com.facebook.presto.orc.OrcDataSourceId; import com.facebook.presto.orc.memory.AggregatedMemoryContext; import io.airlift.slice.BasicSliceInput; import io.airlift.slice.Slice; @@ -147,7 +148,7 @@ private static OrcInputStream decimalInputStream(BigInteger value) private static OrcInputStream orcInputStreamFor(String source, byte[] bytes) { - return new OrcInputStream(source, new BasicSliceInput(Slices.wrappedBuffer(bytes)), Optional.empty(), new AggregatedMemoryContext()); + return new OrcInputStream(new OrcDataSourceId(source), new BasicSliceInput(Slices.wrappedBuffer(bytes)), Optional.empty(), new AggregatedMemoryContext()); } // copied from org.apache.hadoop.hive.ql.io.orc.SerializationUtils.java diff --git a/presto-parser/pom.xml b/presto-parser/pom.xml index f36cd9eca6576..489a2d036a17e 100644 --- a/presto-parser/pom.xml +++ b/presto-parser/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-parser diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 index df57e9f6cf01a..e1886db84f512 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 @@ -71,6 +71,8 @@ statement | SHOW SCHEMAS ((FROM | IN) identifier)? (LIKE pattern=string)? #showSchemas | SHOW CATALOGS (LIKE pattern=string)? #showCatalogs | SHOW COLUMNS (FROM | IN) qualifiedName #showColumns + | SHOW STATS (FOR | ON) qualifiedName #showStats + | SHOW STATS FOR '(' querySpecification ')' #showStatsForQuery | DESCRIBE qualifiedName #showColumns | DESC qualifiedName #showColumns | SHOW FUNCTIONS #showFunctions @@ -217,7 +219,6 @@ sampledRelation sampleType : BERNOULLI | SYSTEM - | POISSONIZED ; aliasedRelation @@ -232,6 +233,7 @@ relationPrimary : qualifiedName #tableName | '(' query ')' #subqueryRelation | UNNEST '(' expression (',' expression)* ')' (WITH ORDINALITY)? #unnest + | LATERAL '(' query ')' #lateral | '(' relation ')' #parenthesizedRelation ; @@ -310,6 +312,7 @@ primaryExpression | NORMALIZE '(' valueExpression (',' normalForm)? ')' #normalize | EXTRACT '(' identifier FROM valueExpression ')' #extract | '(' expression ')' #parenthesizedExpression + | GROUPING '(' (qualifiedName (',' qualifiedName)*)? ')' #groupingOperation ; string @@ -342,6 +345,10 @@ intervalField : YEAR | MONTH | DAY | HOUR | MINUTE | SECOND ; +normalForm + : NFD | NFC | NFKD | NFKC + ; + type : type ARRAY | ARRAY '<' type '>' @@ -440,217 +447,211 @@ number ; nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | SCHEMAS | CATALOGS | SESSION - | ADD - | FILTER - | AT - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY - | TINYINT | SMALLINT | INTEGER | DATE | TIME | TIMESTAMP | INTERVAL | ZONE - | YEAR | MONTH | DAY | HOUR | MINUTE | SECOND - | EXPLAIN | ANALYZE | FORMAT | TYPE | TEXT | GRAPHVIZ | LOGICAL | DISTRIBUTED | VALIDATE - | TABLESAMPLE | SYSTEM | BERNOULLI | POISSONIZED | USE | TO - | SET | RESET - | VIEW | REPLACE - | IF | NULLIF | COALESCE - | normalForm - | POSITION - | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL - | SERIALIZABLE | REPEATABLE | COMMITTED | UNCOMMITTED | READ | WRITE | ONLY - | COMMENT - | CALL - | GRANT | REVOKE | PRIVILEGES | PUBLIC | OPTION | GRANTS - | SUBSTRING - | SCHEMA | CASCADE | RESTRICT - | INPUT | OUTPUT - | INCLUDING | EXCLUDING | PROPERTIES - | ALL | SOME | ANY - ; - -normalForm - : NFD | NFC | NFKD | NFKC + // IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved + : ADD | ALL | ANALYZE | ANY | ARRAY | ASC | AT + | BERNOULLI + | CALL | CASCADE | CATALOGS | COALESCE | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CURRENT + | DATA | DATE | DAY | DESC | DISTRIBUTED + | EXCLUDING | EXPLAIN + | FILTER | FIRST | FOLLOWING | FORMAT | FUNCTIONS + | GRANT | GRANTS | GRAPHVIZ + | HOUR + | IF | INCLUDING | INPUT | INTEGER | INTERVAL | ISOLATION + | LAST | LATERAL | LEVEL | LIMIT | LOGICAL + | MAP | MINUTE | MONTH + | NFC | NFD | NFKC | NFKD | NO | NULLIF | NULLS + | ONLY | OPTION | ORDINALITY | OUTPUT | OVER + | PARTITION | PARTITIONS | POSITION | PRECEDING | PRIVILEGES | PROPERTIES | PUBLIC + | RANGE | READ | RENAME | REPEATABLE | REPLACE | RESET | RESTRICT | REVOKE | ROLLBACK | ROW | ROWS + | SCHEMA | SCHEMAS | SECOND | SERIALIZABLE | SESSION | SET | SETS + | SHOW | SMALLINT | SOME | START | STATS | SUBSTRING | SYSTEM + | TABLES | TABLESAMPLE | TEXT | TIME | TIMESTAMP | TINYINT | TO | TRANSACTION | TRY_CAST | TYPE + | UNBOUNDED | UNCOMMITTED | USE + | VALIDATE | VIEW + | WORK | WRITE + | YEAR + | ZONE ; -SELECT: 'SELECT'; -FROM: 'FROM'; ADD: 'ADD'; -AS: 'AS'; ALL: 'ALL'; -SOME: 'SOME'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; ANY: 'ANY'; -DISTINCT: 'DISTINCT'; -WHERE: 'WHERE'; -GROUP: 'GROUP'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +BERNOULLI: 'BERNOULLI'; +BETWEEN: 'BETWEEN'; BY: 'BY'; -GROUPING: 'GROUPING'; -SETS: 'SETS'; +CALL: 'CALL'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CATALOGS: 'CATALOGS'; +COALESCE: 'COALESCE'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMMITTED: 'COMMITTED'; +CONSTRAINT: 'CONSTRAINT'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; CUBE: 'CUBE'; -ROLLUP: 'ROLLUP'; -ORDER: 'ORDER'; -HAVING: 'HAVING'; -LIMIT: 'LIMIT'; -AT: 'AT'; -OR: 'OR'; -AND: 'AND'; -IN: 'IN'; -NOT: 'NOT'; -NO: 'NO'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +DATA: 'DATA'; +DATE: 'DATE'; +DAY: 'DAY'; +DEALLOCATE: 'DEALLOCATE'; +DELETE: 'DELETE'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DISTINCT: 'DISTINCT'; +DISTRIBUTED: 'DISTRIBUTED'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +EXCEPT: 'EXCEPT'; +EXCLUDING: 'EXCLUDING'; +EXECUTE: 'EXECUTE'; EXISTS: 'EXISTS'; -BETWEEN: 'BETWEEN'; -LIKE: 'LIKE'; -IS: 'IS'; -NULL: 'NULL'; -TRUE: 'TRUE'; +EXPLAIN: 'EXPLAIN'; +EXTRACT: 'EXTRACT'; FALSE: 'FALSE'; -NULLS: 'NULLS'; +FILTER: 'FILTER'; FIRST: 'FIRST'; -LAST: 'LAST'; -ESCAPE: 'ESCAPE'; -ASC: 'ASC'; -DESC: 'DESC'; -SUBSTRING: 'SUBSTRING'; -POSITION: 'POSITION'; +FOLLOWING: 'FOLLOWING'; FOR: 'FOR'; -TINYINT: 'TINYINT'; -SMALLINT: 'SMALLINT'; +FORMAT: 'FORMAT'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTIONS: 'FUNCTIONS'; +GRANT: 'GRANT'; +GRANTS: 'GRANTS'; +GRAPHVIZ: 'GRAPHVIZ'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +HOUR: 'HOUR'; +IF: 'IF'; +IN: 'IN'; +INCLUDING: 'INCLUDING'; +INNER: 'INNER'; +INPUT: 'INPUT'; +INSERT: 'INSERT'; INTEGER: 'INTEGER'; -DATE: 'DATE'; -TIME: 'TIME'; -TIMESTAMP: 'TIMESTAMP'; +INTERSECT: 'INTERSECT'; INTERVAL: 'INTERVAL'; -YEAR: 'YEAR'; -MONTH: 'MONTH'; -DAY: 'DAY'; -HOUR: 'HOUR'; -MINUTE: 'MINUTE'; -SECOND: 'SECOND'; -ZONE: 'ZONE'; -CURRENT_DATE: 'CURRENT_DATE'; -CURRENT_TIME: 'CURRENT_TIME'; -CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; -LOCALTIME: 'LOCALTIME'; -LOCALTIMESTAMP: 'LOCALTIMESTAMP'; -EXTRACT: 'EXTRACT'; -CASE: 'CASE'; -WHEN: 'WHEN'; -THEN: 'THEN'; -ELSE: 'ELSE'; -END: 'END'; +INTO: 'INTO'; +IS: 'IS'; +ISOLATION: 'ISOLATION'; JOIN: 'JOIN'; -CROSS: 'CROSS'; -OUTER: 'OUTER'; -INNER: 'INNER'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; LEFT: 'LEFT'; -RIGHT: 'RIGHT'; -FULL: 'FULL'; +LEVEL: 'LEVEL'; +LIKE: 'LIKE'; +LIMIT: 'LIMIT'; +LOCALTIME: 'LOCALTIME'; +LOCALTIMESTAMP: 'LOCALTIMESTAMP'; +LOGICAL: 'LOGICAL'; +MAP: 'MAP'; +MINUTE: 'MINUTE'; +MONTH: 'MONTH'; NATURAL: 'NATURAL'; -USING: 'USING'; +NFC : 'NFC'; +NFD : 'NFD'; +NFKC : 'NFKC'; +NFKD : 'NFKD'; +NO: 'NO'; +NORMALIZE: 'NORMALIZE'; +NOT: 'NOT'; +NULL: 'NULL'; +NULLIF: 'NULLIF'; +NULLS: 'NULLS'; ON: 'ON'; -FILTER: 'FILTER'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OR: 'OR'; +ORDER: 'ORDER'; +ORDINALITY: 'ORDINALITY'; +OUTER: 'OUTER'; +OUTPUT: 'OUTPUT'; OVER: 'OVER'; PARTITION: 'PARTITION'; -RANGE: 'RANGE'; -ROWS: 'ROWS'; -UNBOUNDED: 'UNBOUNDED'; +PARTITIONS: 'PARTITIONS'; +POSITION: 'POSITION'; PRECEDING: 'PRECEDING'; -FOLLOWING: 'FOLLOWING'; -CURRENT: 'CURRENT'; -ROW: 'ROW'; -WITH: 'WITH'; -RECURSIVE: 'RECURSIVE'; -VALUES: 'VALUES'; -CREATE: 'CREATE'; -SCHEMA: 'SCHEMA'; -TABLE: 'TABLE'; -COMMENT: 'COMMENT'; -VIEW: 'VIEW'; -REPLACE: 'REPLACE'; -INSERT: 'INSERT'; -DELETE: 'DELETE'; -INTO: 'INTO'; -CONSTRAINT: 'CONSTRAINT'; -DESCRIBE: 'DESCRIBE'; -GRANT: 'GRANT'; -REVOKE: 'REVOKE'; +PREPARE: 'PREPARE'; PRIVILEGES: 'PRIVILEGES'; +PROPERTIES: 'PROPERTIES'; PUBLIC: 'PUBLIC'; -OPTION: 'OPTION'; -GRANTS: 'GRANTS'; -EXPLAIN: 'EXPLAIN'; -ANALYZE: 'ANALYZE'; -FORMAT: 'FORMAT'; -TYPE: 'TYPE'; -TEXT: 'TEXT'; -GRAPHVIZ: 'GRAPHVIZ'; -LOGICAL: 'LOGICAL'; -DISTRIBUTED: 'DISTRIBUTED'; -VALIDATE: 'VALIDATE'; -CAST: 'CAST'; -TRY_CAST: 'TRY_CAST'; -SHOW: 'SHOW'; -TABLES: 'TABLES'; -SCHEMAS: 'SCHEMAS'; -CATALOGS: 'CATALOGS'; -COLUMNS: 'COLUMNS'; -COLUMN: 'COLUMN'; -USE: 'USE'; -PARTITIONS: 'PARTITIONS'; -FUNCTIONS: 'FUNCTIONS'; -DROP: 'DROP'; -UNION: 'UNION'; -EXCEPT: 'EXCEPT'; -INTERSECT: 'INTERSECT'; -TO: 'TO'; -SYSTEM: 'SYSTEM'; -BERNOULLI: 'BERNOULLI'; -POISSONIZED: 'POISSONIZED'; -TABLESAMPLE: 'TABLESAMPLE'; -ALTER: 'ALTER'; +RANGE: 'RANGE'; +READ: 'READ'; +RECURSIVE: 'RECURSIVE'; RENAME: 'RENAME'; -UNNEST: 'UNNEST'; -ORDINALITY: 'ORDINALITY'; -ARRAY: 'ARRAY'; -MAP: 'MAP'; -SET: 'SET'; +REPEATABLE: 'REPEATABLE'; +REPLACE: 'REPLACE'; RESET: 'RESET'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; +SECOND: 'SECOND'; +SELECT: 'SELECT'; +SERIALIZABLE: 'SERIALIZABLE'; SESSION: 'SESSION'; -DATA: 'DATA'; +SET: 'SET'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SMALLINT: 'SMALLINT'; +SOME: 'SOME'; START: 'START'; +STATS: 'STATS'; +SUBSTRING: 'SUBSTRING'; +SYSTEM: 'SYSTEM'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TEXT: 'TEXT'; +THEN: 'THEN'; +TIME: 'TIME'; +TIMESTAMP: 'TIMESTAMP'; +TINYINT: 'TINYINT'; +TO: 'TO'; TRANSACTION: 'TRANSACTION'; -COMMIT: 'COMMIT'; -ROLLBACK: 'ROLLBACK'; -WORK: 'WORK'; -ISOLATION: 'ISOLATION'; -LEVEL: 'LEVEL'; -SERIALIZABLE: 'SERIALIZABLE'; -REPEATABLE: 'REPEATABLE'; -COMMITTED: 'COMMITTED'; +TRUE: 'TRUE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UESCAPE: 'UESCAPE'; +UNBOUNDED: 'UNBOUNDED'; UNCOMMITTED: 'UNCOMMITTED'; -READ: 'READ'; +UNION: 'UNION'; +UNNEST: 'UNNEST'; +USE: 'USE'; +USING: 'USING'; +VALIDATE: 'VALIDATE'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WITH: 'WITH'; +WORK: 'WORK'; WRITE: 'WRITE'; -ONLY: 'ONLY'; -CALL: 'CALL'; -PREPARE: 'PREPARE'; -DEALLOCATE: 'DEALLOCATE'; -EXECUTE: 'EXECUTE'; -INPUT: 'INPUT'; -OUTPUT: 'OUTPUT'; -CASCADE: 'CASCADE'; -RESTRICT: 'RESTRICT'; -INCLUDING: 'INCLUDING'; -EXCLUDING: 'EXCLUDING'; -PROPERTIES: 'PROPERTIES'; -UESCAPE: 'UESCAPE'; - -NORMALIZE: 'NORMALIZE'; -NFD : 'NFD'; -NFC : 'NFC'; -NFKD : 'NFKD'; -NFKC : 'NFKC'; - -IF: 'IF'; -NULLIF: 'NULLIF'; -COALESCE: 'COALESCE'; +YEAR: 'YEAR'; +ZONE: 'ZONE'; EQ : '='; NEQ : '<>' | '!='; diff --git a/presto-parser/src/main/antlr4/com/facebook/presto/type/TypeCalculation.g4 b/presto-parser/src/main/antlr4/com/facebook/presto/type/TypeCalculation.g4 index de7817586e858..d86a9ae1c3c38 100644 --- a/presto-parser/src/main/antlr4/com/facebook/presto/type/TypeCalculation.g4 +++ b/presto-parser/src/main/antlr4/com/facebook/presto/type/TypeCalculation.g4 @@ -15,8 +15,10 @@ //TODO: consider using the SQL grammar for this grammar TypeCalculation; +// workaround for: +// https://github.com/antlr/antlr4/issues/118 typeCalculation - : expression + : expression EOF ; expression diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java index ce9c8019e757c..4eecb84f5499f 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/ExpressionFormatter.java @@ -40,6 +40,7 @@ import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.GroupingElement; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.GroupingSets; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IfExpression; @@ -367,9 +368,14 @@ protected String visitLambdaExpression(LambdaExpression node, Void context) @Override protected String visitBindExpression(BindExpression node, Void context) { - return "\"$INTERNAL$BIND\"(" + - process(node.getValue(), context) + ", " + - process(node.getFunction(), context) + ")"; + StringBuilder builder = new StringBuilder(); + + builder.append("\"$INTERNAL$BIND\"("); + for (Expression value : node.getValues()) { + builder.append(process(value, context) + ", "); + } + builder.append(process(node.getFunction(), context) + ")"); + return builder.toString(); } @Override @@ -635,6 +641,11 @@ protected String visitQuantifiedComparisonExpression(QuantifiedComparisonExpress .toString(); } + public String visitGroupingOperation(GroupingOperation node, Void context) + { + return "GROUPING (" + joinExpressions(node.getGroupingColumns()) + ")"; + } + private String formatBinaryExpression(String operator, Expression left, Expression right) { return '(' + process(left, null) + ' ' + operator + ' ' + process(right, null) + ')'; diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/ReservedIdentifiers.java b/presto-parser/src/main/java/com/facebook/presto/sql/ReservedIdentifiers.java new file mode 100644 index 0000000000000..76add99963acb --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/ReservedIdentifiers.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql; + +import com.facebook.presto.sql.parser.ParsingException; +import com.facebook.presto.sql.parser.SqlBaseLexer; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Identifier; +import com.google.common.collect.ImmutableSet; +import org.antlr.v4.runtime.Vocabulary; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; + +public final class ReservedIdentifiers +{ + private static final Pattern IDENTIFIER = Pattern.compile("'([A-Z_]+)'"); + private static final Pattern TABLE_ROW = Pattern.compile("``([A-Z_]+)``.*"); + private static final String TABLE_PREFIX = "============================== "; + + private static final SqlParser PARSER = new SqlParser(); + + private ReservedIdentifiers() {} + + @SuppressWarnings("CallToPrintStackTrace") + public static void main(String[] args) + throws IOException + { + if ((args.length == 2) && args[0].equals("validateDocs")) { + try { + validateDocs(Paths.get(args[1])); + } + catch (Throwable t) { + t.printStackTrace(); + System.exit(100); + } + } + else { + for (String name : reservedIdentifiers()) { + System.out.println(name); + } + } + } + + private static void validateDocs(Path path) + throws IOException + { + System.out.println("Validating " + path); + List lines = Files.readAllLines(path); + + if (lines.stream().filter(s -> s.startsWith(TABLE_PREFIX)).count() != 3) { + throw new RuntimeException("Failed to find exactly one table"); + } + + Iterator iterator = lines.iterator(); + + // find table and skip header + while (!iterator.next().startsWith(TABLE_PREFIX)) { + // skip + } + if (iterator.next().startsWith(TABLE_PREFIX)) { + throw new RuntimeException("Expected to find a header line"); + } + if (!iterator.next().startsWith(TABLE_PREFIX)) { + throw new RuntimeException("Found multiple header lines"); + } + + Set reserved = reservedIdentifiers(); + Set found = new HashSet<>(); + while (true) { + String line = iterator.next(); + if (line.startsWith(TABLE_PREFIX)) { + break; + } + + Matcher matcher = TABLE_ROW.matcher(line); + if (!matcher.matches()) { + throw new RuntimeException("Invalid table line: " + line); + } + String name = matcher.group(1); + + if (!reserved.contains(name)) { + throw new RuntimeException("Documented identifier is not reserved: " + name); + } + if (!found.add(name)) { + throw new RuntimeException("Duplicate documented identifier: " + name); + } + } + + for (String name : reserved) { + if (!found.contains(name)) { + throw new RuntimeException("Reserved identifier is not documented: " + name); + } + } + + System.out.println(format("Validated %s reserved identifiers", reserved.size())); + } + + public static Set reservedIdentifiers() + { + return possibleIdentifiers().stream() + .filter(ReservedIdentifiers::reserved) + .sorted() + .collect(toImmutableSet()); + } + + private static Set possibleIdentifiers() + { + ImmutableSet.Builder names = ImmutableSet.builder(); + Vocabulary vocabulary = SqlBaseLexer.VOCABULARY; + for (int i = 0; i <= vocabulary.getMaxTokenType(); i++) { + String name = nullToEmpty(vocabulary.getLiteralName(i)); + Matcher matcher = IDENTIFIER.matcher(name); + if (matcher.matches()) { + names.add(matcher.group(1)); + } + } + return names.build(); + } + + private static boolean reserved(String name) + { + try { + return !(PARSER.createExpression(name) instanceof Identifier); + } + catch (ParsingException ignored) { + return true; + } + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java index 67b3f59e6dbf9..6aeaf74f3f676 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/SqlFormatter.java @@ -47,6 +47,7 @@ import com.facebook.presto.sql.tree.JoinCriteria; import com.facebook.presto.sql.tree.JoinOn; import com.facebook.presto.sql.tree.JoinUsing; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.Node; @@ -75,6 +76,7 @@ import com.facebook.presto.sql.tree.ShowPartitions; import com.facebook.presto.sql.tree.ShowSchemas; import com.facebook.presto.sql.tree.ShowSession; +import com.facebook.presto.sql.tree.ShowStats; import com.facebook.presto.sql.tree.ShowTables; import com.facebook.presto.sql.tree.SingleColumn; import com.facebook.presto.sql.tree.StartTransaction; @@ -154,6 +156,15 @@ protected Void visitUnnest(Unnest node, Integer indent) return null; } + @Override + protected Void visitLateral(Lateral node, Integer indent) + { + append(indent, "LATERAL ("); + process(node.getQuery(), indent + 1); + append(indent, ")"); + return null; + } + @Override protected Void visitPrepare(Prepare node, Integer indent) { @@ -637,6 +648,15 @@ protected Void visitShowColumns(ShowColumns node, Integer context) return null; } + @Override + protected Void visitShowStats(ShowStats node, Integer context) + { + builder.append("SHOW STATS FOR "); + process(node.getRelation(), 0); + builder.append(""); + return null; + } + @Override protected Void visitShowPartitions(ShowPartitions node, Integer context) { diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index 8bcd9356022c2..ad58155ab9fee 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -66,6 +66,7 @@ import com.facebook.presto.sql.tree.Grant; import com.facebook.presto.sql.tree.GroupBy; import com.facebook.presto.sql.tree.GroupingElement; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.GroupingSets; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IfExpression; @@ -83,6 +84,7 @@ import com.facebook.presto.sql.tree.JoinUsing; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; @@ -123,6 +125,7 @@ import com.facebook.presto.sql.tree.ShowPartitions; import com.facebook.presto.sql.tree.ShowSchemas; import com.facebook.presto.sql.tree.ShowSession; +import com.facebook.presto.sql.tree.ShowStats; import com.facebook.presto.sql.tree.ShowTables; import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.SimpleGroupBy; @@ -164,6 +167,7 @@ import java.util.Optional; import java.util.stream.Collectors; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -678,6 +682,20 @@ public Node visitShowColumns(SqlBaseParser.ShowColumnsContext context) return new ShowColumns(getLocation(context), getQualifiedName(context.qualifiedName())); } + @Override + public Node visitShowStats(SqlBaseParser.ShowStatsContext context) + { + return new ShowStats(Optional.of(getLocation(context)), new Table(getQualifiedName(context.qualifiedName()))); + } + + @Override + public Node visitShowStatsForQuery(SqlBaseParser.ShowStatsForQueryContext context) + { + QuerySpecification specification = (QuerySpecification) visitQuerySpecification(context.querySpecification()); + Query query = new Query(Optional.empty(), specification, Optional.empty(), Optional.empty()); + return new ShowStats(Optional.of(getLocation(context)), new TableSubquery(query)); + } + @Override public Node visitShowPartitions(SqlBaseParser.ShowPartitionsContext context) { @@ -895,6 +913,12 @@ public Node visitUnnest(SqlBaseParser.UnnestContext context) return new Unnest(getLocation(context), visit(context.expression(), Expression.class), context.ORDINALITY() != null); } + @Override + public Node visitLateral(SqlBaseParser.LateralContext context) + { + return new Lateral(getLocation(context), (Query) visit(context.query())); + } + @Override public Node visitParenthesizedRelation(SqlBaseParser.ParenthesizedRelationContext context) { @@ -1262,14 +1286,20 @@ public Node visitFunctionCall(SqlBaseParser.FunctionCallContext context) return new TryExpression(getLocation(context), (Expression) visit(getOnlyElement(context.expression()))); } if (name.toString().equalsIgnoreCase("$internal$bind")) { - check(context.expression().size() == 2, "The '$internal$bind' function must have exactly two arguments", context); + check(context.expression().size() >= 1, "The '$internal$bind' function must have at least one arguments", context); check(!window.isPresent(), "OVER clause not valid for '$internal$bind' function", context); check(!distinct, "DISTINCT not valid for '$internal$bind' function", context); + int numValues = context.expression().size() - 1; + List arguments = context.expression().stream() + .map(this::visit) + .map(Expression.class::cast) + .collect(toImmutableList()); + return new BindExpression( getLocation(context), - (Expression) visit(context.expression(0)), - (Expression) visit(context.expression(1))); + arguments.subList(0, numValues), + arguments.get(numValues)); } return new FunctionCall( @@ -1291,7 +1321,7 @@ public Node visitLambda(SqlBaseParser.LambdaContext context) Expression body = (Expression) visit(context.expression()); - return new LambdaExpression(arguments, body); + return new LambdaExpression(getLocation(context), arguments, body); } @Override @@ -1377,6 +1407,16 @@ public Node visitCurrentRowBound(SqlBaseParser.CurrentRowBoundContext context) return new FrameBound(getLocation(context), FrameBound.Type.CURRENT_ROW); } + @Override + public Node visitGroupingOperation(SqlBaseParser.GroupingOperationContext context) + { + List arguments = context.qualifiedName().stream() + .map(AstBuilder::getQualifiedName) + .collect(toList()); + + return new GroupingOperation(Optional.of(getLocation(context)), arguments); + } + // ************** literals ************** @Override diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java index 48cfa1e664876..6f68366281bf8 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/SqlParser.java @@ -27,10 +27,13 @@ import org.antlr.v4.runtime.atn.PredictionMode; import org.antlr.v4.runtime.misc.Pair; import org.antlr.v4.runtime.misc.ParseCancellationException; +import org.antlr.v4.runtime.tree.TerminalNode; import javax.inject.Inject; +import java.util.Arrays; import java.util.EnumSet; +import java.util.List; import java.util.function.Function; import static java.util.Objects.requireNonNull; @@ -77,7 +80,7 @@ private Node invokeParser(String name, String sql, Function ruleNames; + + public PostProcessor(List ruleNames) + { + this.ruleNames = ruleNames; + } + @Override public void exitUnquotedIdentifier(SqlBaseParser.UnquotedIdentifierContext context) { @@ -162,6 +172,13 @@ public void exitQuotedIdentifier(SqlBaseParser.QuotedIdentifierContext context) @Override public void exitNonReserved(SqlBaseParser.NonReservedContext context) { + // we can't modify the tree during rule enter/exit event handling unless we're dealing with a terminal. + // Otherwise, ANTLR gets confused an fires spurious notifications. + if (!(context.getChild(0) instanceof TerminalNode)) { + int rule = ((ParserRuleContext) context.getChild(0)).getRuleIndex(); + throw new AssertionError("nonReserved can only contain tokens. Found nested rule: " + ruleNames.get(rule)); + } + // replace nonReserved words with IDENT tokens context.getParent().removeLastChild(); diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java index bb4ef1eca444c..7e3dc4d21b920 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/AstVisitor.java @@ -142,6 +142,11 @@ protected R visitShowColumns(ShowColumns node, C context) return visitStatement(node, context); } + protected R visitShowStats(ShowStats node, C context) + { + return visitStatement(node, context); + } + protected R visitShowPartitions(ShowPartitions node, C context) { return visitStatement(node, context); @@ -422,6 +427,11 @@ protected R visitUnnest(Unnest node, C context) return visitRelation(node, context); } + protected R visitLateral(Lateral node, C context) + { + return visitRelation(node, context); + } + protected R visitValues(Values node, C context) { return visitQueryBody(node, context); @@ -676,4 +686,9 @@ protected R visitBindExpression(BindExpression node, C context) { return visitExpression(node, context); } + + protected R visitGroupingOperation(GroupingOperation node, C context) + { + return visitExpression(node, context); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/BindExpression.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/BindExpression.java index 49b5fd4e58c2d..8ab8f4ed57793 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/BindExpression.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/BindExpression.java @@ -47,31 +47,31 @@ public class BindExpression extends Expression { - private final Expression value; + private final List values; // Function expression must be of function type. // It is not necessarily a lambda. For example, it can be another bind expression. private final Expression function; - public BindExpression(Expression value, Expression function) + public BindExpression(List values, Expression function) { - this(Optional.empty(), value, function); + this(Optional.empty(), values, function); } - public BindExpression(NodeLocation location, Expression value, Expression function) + public BindExpression(NodeLocation location, List values, Expression function) { - this(Optional.of(location), value, function); + this(Optional.of(location), values, function); } - private BindExpression(Optional location, Expression value, Expression function) + private BindExpression(Optional location, List values, Expression function) { super(location); - this.value = requireNonNull(value, "value is null"); + this.values = requireNonNull(values, "value is null"); this.function = requireNonNull(function, "function is null"); } - public Expression getValue() + public List getValues() { - return value; + return values; } public Expression getFunction() @@ -89,7 +89,7 @@ public R accept(AstVisitor visitor, C context) public List getChildren() { ImmutableList.Builder nodes = ImmutableList.builder(); - return nodes.add(value) + return nodes.addAll(values) .add(function) .build(); } @@ -104,13 +104,13 @@ public boolean equals(Object o) return false; } BindExpression that = (BindExpression) o; - return Objects.equals(value, that.value) && + return Objects.equals(values, that.values) && Objects.equals(function, that.function); } @Override public int hashCode() { - return Objects.hash(value, function); + return Objects.hash(values, function); } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java index 40902927fdbe4..02990e504de6c 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/DefaultTraversalVisitor.java @@ -181,6 +181,16 @@ protected R visitFunctionCall(FunctionCall node, C context) return null; } + @Override + protected R visitGroupingOperation(GroupingOperation node, C context) + { + for (Expression columnArgument : node.getGroupingColumns()) { + process(columnArgument, context); + } + + return null; + } + @Override protected R visitDereferenceExpression(DereferenceExpression node, C context) { @@ -282,7 +292,9 @@ protected R visitTryExpression(TryExpression node, C context) @Override protected R visitBindExpression(BindExpression node, C context) { - process(node.getValue(), context); + for (Expression value : node.getValues()) { + process(value, context); + } process(node.getFunction(), context); return null; @@ -594,4 +606,12 @@ protected R visitExists(ExistsPredicate node, C context) return null; } + + @Override + protected R visitLateral(Lateral node, C context) + { + process(node.getQuery(), context); + + return super.visitLateral(node, context); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java index b616d926ab473..3528ae74055de 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionRewriter.java @@ -204,4 +204,9 @@ public Expression rewriteQuantifiedComparison(QuantifiedComparisonExpression nod { return rewriteExpression(node, context, treeRewriter); } + + public Expression rewriteGroupingOperation(GroupingOperation node, C context, ExpressionTreeRewriter treeRewriter) + { + return rewriteExpression(node, context, treeRewriter); + } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java index c457514d79f71..45dec0feee17d 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ExpressionTreeRewriter.java @@ -20,6 +20,8 @@ import java.util.List; import java.util.Optional; +import static com.google.common.collect.ImmutableList.toImmutableList; + public final class ExpressionTreeRewriter { private final ExpressionRewriter rewriter; @@ -598,13 +600,14 @@ protected Expression visitBindExpression(BindExpression node, Context context } } - Expression value = rewrite(node.getValue(), context.get()); + List values = node.getValues().stream() + .map(value -> rewrite(value, context.get())) + .collect(toImmutableList()); Expression function = rewrite(node.getFunction(), context.get()); - if ((value != node.getValue()) || (function != node.getFunction())) { - return new BindExpression(value, function); + if (!sameElements(values, node.getValues()) || (function != node.getFunction())) { + return new BindExpression(values, function); } - return node; } @@ -855,6 +858,19 @@ protected Expression visitQuantifiedComparisonExpression(QuantifiedComparisonExp return node; } + + @Override + public Expression visitGroupingOperation(GroupingOperation node, Context context) + { + if (!context.isDefaultRewrite()) { + Expression result = rewriter.rewriteGroupingOperation(node, context.get(), ExpressionTreeRewriter.this); + if (result != null) { + return result; + } + } + + return node; + } } public static class Context diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingOperation.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingOperation.java new file mode 100644 index 0000000000000..a339f5112141b --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingOperation.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class GroupingOperation + extends Expression +{ + private final List groupingColumns; + + public GroupingOperation(Optional location, List groupingColumns) + { + super(location); + requireNonNull(groupingColumns); + checkArgument(!groupingColumns.isEmpty(), "grouping operation columns cannot be empty"); + this.groupingColumns = groupingColumns.stream() + .map(DereferenceExpression::from) + .collect(toImmutableList()); + } + + public List getGroupingColumns() + { + return groupingColumns; + } + + @Override + protected R accept(AstVisitor visitor, C context) + { + return visitor.visitGroupingOperation(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GroupingOperation other = (GroupingOperation) o; + return Objects.equals(groupingColumns, other.groupingColumns); + } + + @Override + public int hashCode() + { + return Objects.hash(groupingColumns); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingSets.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingSets.java index aedd605be72fc..4b8da732bcc65 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingSets.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/GroupingSets.java @@ -24,6 +24,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; @@ -47,7 +48,7 @@ private GroupingSets(Optional location, List> super(location); requireNonNull(sets, "sets is null"); checkArgument(!sets.isEmpty(), "grouping sets cannot be empty"); - this.sets = ImmutableList.copyOf(sets.stream().map(ImmutableList::copyOf).collect(Collectors.toList())); + this.sets = sets.stream().map(ImmutableList::copyOf).collect(toImmutableList()); } public List> getSets() diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/Lateral.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Lateral.java new file mode 100644 index 0000000000000..2892040839977 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/Lateral.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public final class Lateral + extends Relation +{ + private final Query query; + + public Lateral(Query query) + { + this(Optional.empty(), query); + } + + public Lateral(NodeLocation location, Query query) + { + this(Optional.of(location), query); + } + + private Lateral(Optional location, Query query) + { + super(location); + this.query = requireNonNull(query, "query is null"); + } + + public Query getQuery() + { + return query; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitLateral(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(query); + } + + @Override + public String toString() + { + return "LATERAL(" + query + ")"; + } + + @Override + public int hashCode() + { + return Objects.hash(query); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Lateral other = (Lateral) obj; + return Objects.equals(this.query, other.query); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/NodeRef.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/NodeRef.java new file mode 100644 index 0000000000000..ca4dce0a1a1c5 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/NodeRef.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import static java.lang.String.format; +import static java.lang.System.identityHashCode; +import static java.util.Objects.requireNonNull; + +public final class NodeRef +{ + public static NodeRef of(T node) + { + return new NodeRef<>(node); + } + + private final T node; + + private NodeRef(T node) + { + this.node = requireNonNull(node, "node is null"); + } + + public T getNode() + { + return node; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NodeRef other = (NodeRef) o; + return node == other.node; + } + + @Override + public int hashCode() + { + return identityHashCode(node); + } + + @Override + public String toString() + { + return format( + "@%s: %s", + Integer.toHexString(identityHashCode(node)), + node); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/tree/ShowStats.java b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ShowStats.java new file mode 100644 index 0000000000000..a55706328da60 --- /dev/null +++ b/presto-parser/src/main/java/com/facebook/presto/sql/tree/ShowStats.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.tree; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class ShowStats + extends Statement +{ + private final Relation relation; + + @VisibleForTesting + public ShowStats(Relation relation) + { + this(Optional.empty(), relation); + } + + public ShowStats(Optional location, Relation relation) + { + super(location); + this.relation = relation; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitShowStats(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(relation); + } + + @Override + public int hashCode() + { + return Objects.hash(relation); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + ShowStats o = (ShowStats) obj; + return Objects.equals(relation, o.relation); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("query", relation) + .toString(); + } + + public Relation getRelation() + { + return relation; + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java b/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java index 510febf77b8bc..b5b63cffa1662 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/util/AstUtils.java @@ -13,67 +13,30 @@ */ package com.facebook.presto.sql.util; -import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Node; +import com.google.common.collect.TreeTraverser; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.Iterator; -import java.util.Spliterators; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Stream; -import java.util.stream.StreamSupport; + +import static com.google.common.collect.Iterables.unmodifiableIterable; +import static java.util.Objects.requireNonNull; public class AstUtils { public static boolean nodeContains(Node node, Node subNode) { - return new DefaultTraversalVisitor() - { - @Override - public Boolean process(Node node, AtomicBoolean findResultHolder) - { - if (!findResultHolder.get()) { - if (node == subNode) { - findResultHolder.set(true); - } - else { - super.process(node, findResultHolder); - } - } - return findResultHolder.get(); - } - }.process(node, new AtomicBoolean(false)); - } + requireNonNull(node, "node is null"); + requireNonNull(subNode, "subNode is null"); - public static Stream preOrder(Node node) - { - return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new PreOrderIterator(node), 0), false); + return preOrder(node) + .anyMatch(childNode -> childNode == subNode); } - private static final class PreOrderIterator - implements Iterator + public static Stream preOrder(Node node) { - private final Deque remaining = new ArrayDeque<>(); - - public PreOrderIterator(Node node) - { - remaining.push(node); - } - - @Override - public boolean hasNext() - { - return remaining.size() > 0; - } - - @Override - public Node next() - { - Node node = remaining.pop(); - node.getChildren().forEach(remaining::push); - return node; - } + return TreeTraverser.using((Node n) -> unmodifiableIterable(n.getChildren())) + .preOrderTraversal(requireNonNull(node, "node is null")) + .stream(); } private AstUtils() {} diff --git a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java index 7d8c187bfa05f..04895a319aa45 100644 --- a/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/com/facebook/presto/sql/parser/TestSqlParser.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.parser; import com.facebook.presto.sql.tree.AddColumn; +import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.AllColumns; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArrayConstructor; @@ -55,6 +56,7 @@ import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.Grant; import com.facebook.presto.sql.tree.GroupBy; +import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.GroupingSets; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Insert; @@ -67,6 +69,7 @@ import com.facebook.presto.sql.tree.JoinOn; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.Lateral; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; @@ -89,6 +92,8 @@ import com.facebook.presto.sql.tree.Rollback; import com.facebook.presto.sql.tree.Rollup; import com.facebook.presto.sql.tree.Row; +import com.facebook.presto.sql.tree.Select; +import com.facebook.presto.sql.tree.SelectItem; import com.facebook.presto.sql.tree.SetSession; import com.facebook.presto.sql.tree.ShowCatalogs; import com.facebook.presto.sql.tree.ShowColumns; @@ -96,6 +101,7 @@ import com.facebook.presto.sql.tree.ShowPartitions; import com.facebook.presto.sql.tree.ShowSchemas; import com.facebook.presto.sql.tree.ShowSession; +import com.facebook.presto.sql.tree.ShowStats; import com.facebook.presto.sql.tree.ShowTables; import com.facebook.presto.sql.tree.SimpleGroupBy; import com.facebook.presto.sql.tree.SingleColumn; @@ -106,6 +112,7 @@ import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TimeLiteral; import com.facebook.presto.sql.tree.TimestampLiteral; import com.facebook.presto.sql.tree.TransactionAccessMode; @@ -120,6 +127,8 @@ import com.google.common.collect.Lists; import org.testng.annotations.Test; +import java.util.Arrays; +import java.util.List; import java.util.Optional; import static com.facebook.presto.sql.QueryUtil.identifier; @@ -136,6 +145,8 @@ import static com.facebook.presto.sql.testing.TreeAssertions.assertFormattedSql; import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.negative; import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.positive; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN; import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.nCopies; @@ -1066,6 +1077,27 @@ public void testSelectWithGroupBy() Optional.empty(), Optional.empty())); + assertStatement("SELECT a, b, GROUPING(a, b) FROM table1 GROUP BY GROUPING SETS ((a), (b))", + new Query( + Optional.empty(), + new QuerySpecification( + selectList( + DereferenceExpression.from(QualifiedName.of("a")), + DereferenceExpression.from(QualifiedName.of("b")), + new GroupingOperation( + Optional.empty(), + ImmutableList.of(QualifiedName.of("a"), QualifiedName.of("b")) + ) + ), + Optional.of(new Table(QualifiedName.of("table1"))), + Optional.empty(), + Optional.of(new GroupBy(false, ImmutableList.of(new GroupingSets(ImmutableList.of(ImmutableList.of(QualifiedName.of("a")), ImmutableList.of(QualifiedName.of("b"))))))), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty())); + assertStatement("SELECT * FROM table1 GROUP BY ALL GROUPING SETS ((a, b), (a), ()), CUBE (c), ROLLUP (d)", new Query( Optional.empty(), @@ -1107,6 +1139,12 @@ public void testSelectWithGroupBy() Optional.empty())); } + @Test + public void testGroupingFunctionWithExpressions() + { + assertInvalidStatement("SELECT grouping(a+2) FROM (VALUES (1)) AS t (a) GROUP BY a+2", "line 1:18: mismatched input '+' expecting {'.', ')', ','}"); + } + @Test public void testCreateSchema() { @@ -1592,6 +1630,52 @@ public void testUnnest() new Table(QualifiedName.of("t")), new Unnest(ImmutableList.of(new Identifier("a")), true), Optional.empty()))); + assertStatement("SELECT * FROM t FULL JOIN UNNEST(a) ON true", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.FULL, + new Table(QualifiedName.of("t")), + new Unnest(ImmutableList.of(new Identifier("a")), true), + Optional.of(new JoinOn(BooleanLiteral.TRUE_LITERAL))))); + } + + @Test + public void testLateral() + throws Exception + { + Lateral lateralRelation = new Lateral(new Query( + Optional.empty(), + new Values(ImmutableList.of(new LongLiteral("1"))), + Optional.empty(), + Optional.empty())); + + assertStatement("SELECT * FROM t, LATERAL (VALUES 1) a(x)", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.IMPLICIT, + new Table(QualifiedName.of("t")), + new AliasedRelation(lateralRelation, "a", ImmutableList.of("x")), + Optional.empty()))); + + assertStatement("SELECT * FROM t CROSS JOIN LATERAL (VALUES 1) ", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.CROSS, + new Table(QualifiedName.of("t")), + lateralRelation, + Optional.empty()))); + + assertStatement("SELECT * FROM t FULL JOIN LATERAL (VALUES 1) ON true", + simpleQuery( + selectList(new AllColumns()), + new Join( + Join.Type.FULL, + new Table(QualifiedName.of("t")), + lateralRelation, + Optional.of(new JoinOn(BooleanLiteral.TRUE_LITERAL))))); } @Test @@ -1708,6 +1792,12 @@ public void testNonReserved() new Identifier("SOME"), new Identifier("ANY")), table(QualifiedName.of("t")))); + + assertExpression("stats", new Identifier("stats")); + assertExpression("nfd", new Identifier("nfd")); + assertExpression("nfc", new Identifier("nfc")); + assertExpression("nfkd", new Identifier("nfkd")); + assertExpression("nfkc", new Identifier("nfkc")); } @Test @@ -1805,6 +1895,59 @@ private static ExistsPredicate exists(Query query) return new ExistsPredicate(new SubqueryExpression(query)); } + @Test + public void testShowStats() + { + final String[] tableNames = {"t", "s.t", "c.s.t"}; + + for (String fullName : tableNames) { + QualifiedName qualifiedName = QualifiedName.of(Arrays.asList(fullName.split("\\."))); + assertStatement(format("SHOW STATS FOR %s", qualifiedName), new ShowStats(new Table(qualifiedName))); + assertStatement(format("SHOW STATS ON %s", qualifiedName), new ShowStats(new Table(qualifiedName))); + } + } + + @Test + public void testShowStatsForQuery() + { + final String[] tableNames = {"t", "s.t", "c.s.t"}; + + for (String fullName : tableNames) { + QualifiedName qualifiedName = QualifiedName.of(Arrays.asList(fullName.split("\\."))); + assertStatement(format("SHOW STATS FOR (SELECT * FROM %s)", qualifiedName), + createShowStats(qualifiedName, ImmutableList.of(new AllColumns()), Optional.empty())); + assertStatement(format("SHOW STATS FOR (SELECT * FROM %s WHERE field > 0)", qualifiedName), + createShowStats(qualifiedName, + ImmutableList.of(new AllColumns()), + Optional.of( + new ComparisonExpression(GREATER_THAN, + new Identifier("field"), + new LongLiteral("0"))))); + assertStatement(format("SHOW STATS FOR (SELECT * FROM %s WHERE field > 0 or field < 0)", qualifiedName), + createShowStats(qualifiedName, + ImmutableList.of(new AllColumns()), + Optional.of( + new LogicalBinaryExpression(LogicalBinaryExpression.Type.OR, + new ComparisonExpression(GREATER_THAN, + new Identifier("field"), + new LongLiteral("0")), + new ComparisonExpression(LESS_THAN, + new Identifier("field"), + new LongLiteral("0")) + ) + ))); + } + } + + private ShowStats createShowStats(QualifiedName name, List selects, Optional where) + { + return new ShowStats( + new TableSubquery(simpleQuery(new Select(false, selects), + new Table(name), + where, + Optional.empty()))); + } + @Test public void testDescribeOutput() { @@ -1849,7 +1992,7 @@ public void testQuantifiedComparison() { assertExpression("col1 < ANY (SELECT col2 FROM table1)", new QuantifiedComparisonExpression( - ComparisonExpressionType.LESS_THAN, + LESS_THAN, QuantifiedComparisonExpression.Quantifier.ANY, identifier("col1"), new SubqueryExpression(simpleQuery(selectList(new SingleColumn(identifier("col2"))), table(QualifiedName.of("table1")))) @@ -1886,6 +2029,24 @@ private static void assertStatement(String query, Statement expected) assertFormattedSql(SQL_PARSER, expected); } + private static void assertInvalidStatement(String query, String expectedMessage) + { + try { + SQL_PARSER.createStatement(query); + fail(format("Expected statement to fail: %s", query)); + } + catch (RuntimeException ex) { + assertExceptionMessage(query, ex, expectedMessage); + } + } + + private static void assertExceptionMessage(String sql, Exception exception, String expectedMessage) + { + if (!exception.getMessage().equals(expectedMessage)) { + fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), expectedMessage, sql)); + } + } + private static void assertExpression(String expression, Expression expected) { assertParsed(expression, expected, SQL_PARSER.createExpression(expression)); diff --git a/presto-plugin-toolkit/pom.xml b/presto-plugin-toolkit/pom.xml index 7681b0e191a14..eeaa24cadf90d 100644 --- a/presto-plugin-toolkit/pom.xml +++ b/presto-plugin-toolkit/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-plugin-toolkit diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java index 7c8505d3e16af..d62156444afd5 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/AllowAllAccessControl.java @@ -117,12 +117,12 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String property } @Override - public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { } @Override - public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { } } diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java index defc5e7ed3df4..18dcf8159543e 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/FileBasedAccessControl.java @@ -209,7 +209,7 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String property } @Override - public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { if (!checkTablePermission(identity, tableName, OWNERSHIP)) { denyGrantTablePrivilege(privilege.name(), tableName.toString()); @@ -217,7 +217,7 @@ public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, } @Override - public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { if (!checkTablePermission(identity, tableName, OWNERSHIP)) { denyRevokeTablePrivilege(privilege.name(), tableName.toString()); diff --git a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java index cb620fad2752a..da7791eeeed27 100644 --- a/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java +++ b/presto-plugin-toolkit/src/main/java/com/facebook/presto/plugin/base/security/ReadOnlyAccessControl.java @@ -143,13 +143,13 @@ public void checkCanSetCatalogSessionProperty(Identity identity, String property } @Override - public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { denyGrantTablePrivilege(privilege.name(), tableName.toString()); } @Override - public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName) + public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { denyRevokeTablePrivilege(privilege.name(), tableName.toString()); } diff --git a/presto-postgresql/pom.xml b/presto-postgresql/pom.xml index f198858aefa3b..1359eb4f865cc 100644 --- a/presto-postgresql/pom.xml +++ b/presto-postgresql/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-postgresql diff --git a/presto-product-tests/README.md b/presto-product-tests/README.md index 675c28200b3b0..a7a97c2b1d2dd 100644 --- a/presto-product-tests/README.md +++ b/presto-product-tests/README.md @@ -221,6 +221,7 @@ groups. | HDFS impersonation | ``hdfs_impersonation`` | ``singlenode-hdfs-impersonation``, ``singlenode-kerberos-hdfs-impersonation`` | | No HDFS impersonation | ``hdfs_no_impersonation`` | ``singlenode``, ``singlenode-kerberos-hdfs-no_impersonation`` | | LDAP | ``ldap`` | ``singlenode-ldap`` | +| SQL Server | ``sqlserver`` | ``singlenode-sqlserver`` | Below is a list of commands that explain how to run these profile specific tests and also the entire test suite: @@ -247,6 +248,12 @@ and also the entire test suite: ``` presto-product-tests/bin/run_on_docker.sh singlenode-ldap -g ldap ``` +* Run **SQL Server** tests: + + ``` + presto-product-tests/bin/run_on_docker.sh singlenode-sqlserver -g sqlserver + ``` + * Run the **entire test suite** excluding all profile specific tests, where <profile> can be any one of the available profiles: @@ -255,7 +262,7 @@ be any one of the available profiles: ``` Note: SQL Server product-tests use `microsoft/mssql-server-linux` docker container. -By running SQL Server product tests you accept the license [ACCEPT_EULA](go.microsoft.com/fwlink/?LinkId=746388) +By running SQL Server product tests you accept the license [ACCEPT_EULA](https://go.microsoft.com/fwlink/?LinkId=746388) ### Running from IntelliJ @@ -321,19 +328,20 @@ The format of `/etc/hosts` entries is ` `: ``` docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q hadoop-master) | grep -i IPAddress ``` + Similarly add mappings for MySQL, Postgres and Cassandra containers (`mysql`, `postgres` and `cassandra` hostnames respectively). + To check IPs for those containers run: - Similarly add mappings for MySQL, Postgres and Cassandra containers (`mysql`, `postgres` and `cassandra` hostnames respectively). To check IPs for those containers run: - - ``` - docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q mysql) | grep -i IPAddress - docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q postgres) | grep -i IPAddress - docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q cassandra) | grep -i IPAddress + ``` + docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q mysql) | grep -i IPAddress + docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q postgres) | grep -i IPAddress + docker inspect $(presto-product-tests/conf/docker/singlenode/compose.sh ps -q cassandra) | grep -i IPAddress + ``` Alternatively you can use below script to obtain hosts ip mapping - ``` - presto-product-tests/bin/hosts.sh singlenode - ``` + ``` + presto-product-tests/bin/hosts.sh singlenode + ``` Note that above command requires [jq](https://stedolan.github.io/jq/) to be installed in your system diff --git a/presto-product-tests/bin/run_on_docker.sh b/presto-product-tests/bin/run_on_docker.sh index eaf0195a08a3a..fcfbfc83d2c96 100755 --- a/presto-product-tests/bin/run_on_docker.sh +++ b/presto-product-tests/bin/run_on_docker.sh @@ -45,7 +45,7 @@ function run_in_application_runner_container() { function check_presto() { run_in_application_runner_container \ java -jar "/docker/volumes/presto-cli/presto-cli-executable.jar" \ - --server presto-master:8080 \ + ${CLI_ARGUMENTS} \ --execute "SHOW CATALOGS" | grep -i hive } @@ -81,8 +81,11 @@ function stop_application_runner_containers() { echo "Container stopped: ${CONTAINER_NAME}" done echo "Removing dead application-runner containers" - docker ps -aq --no-trunc --filter status=dead --filter status=exited --filter name=common_application-runner \ - | xargs docker rm -v || true + local CONTAINERS=`docker ps -aq --no-trunc --filter status=dead --filter status=exited --filter name=common_application-runner` + for CONTAINER in ${CONTAINERS}; + do + docker rm -v "${CONTAINER}" + done } function stop_all_containers() { @@ -177,6 +180,13 @@ shift 1 PRESTO_SERVICES="presto-master" if [[ "$ENVIRONMENT" == "multinode" ]]; then PRESTO_SERVICES="${PRESTO_SERVICES} presto-worker" +elif [[ "$ENVIRONMENT" == "multinode-tls" ]]; then + PRESTO_SERVICES="${PRESTO_SERVICES} presto-worker-1 presto-worker-2" +fi + +CLI_ARGUMENTS="--server presto-master:8080" +if [[ "$ENVIRONMENT" == "multinode-tls" ]]; then + CLI_ARGUMENTS="--server https://presto-master.docker.cluster:7778 --keystore-path /docker/volumes/conf/presto/etc/docker.cluster.jks --keystore-password 123456" fi # check docker and docker compose installation diff --git a/presto-product-tests/conf/docker/common/compose-commons.sh b/presto-product-tests/conf/docker/common/compose-commons.sh index e93ab4280fe28..130cc6cb768a4 100644 --- a/presto-product-tests/conf/docker/common/compose-commons.sh +++ b/presto-product-tests/conf/docker/common/compose-commons.sh @@ -23,8 +23,8 @@ function export_canonical_path() { source ${BASH_SOURCE%/*}/../../../bin/locations.sh -export DOCKER_IMAGES_VERSION=${DOCKER_IMAGES_VERSION:-14} -export HADOOP_MASTER_IMAGE=${HADOOP_MASTER_IMAGE:-"teradatalabs/cdh5-hive:${DOCKER_IMAGES_VERSION}"} +export DOCKER_IMAGES_VERSION=${DOCKER_IMAGES_VERSION:-19} +export HADOOP_MASTER_IMAGE=${HADOOP_MASTER_IMAGE:-"teradatalabs/hdp2.5-hive:${DOCKER_IMAGES_VERSION}"} # The following variables are defined to enable running product tests with arbitrary/downloaded jars # and without building the project. The `presto.env` file should only be sourced if any of the variables diff --git a/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh b/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh index 3538338dba6b8..9134bb678fe36 100755 --- a/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh +++ b/presto-product-tests/conf/docker/files/presto-launcher-wrapper.sh @@ -4,13 +4,14 @@ set -e CONFIG="$1" -if [[ "$CONFIG" != "singlenode" && "$CONFIG" != "multinode-master" && "$CONFIG" != "multinode-worker" && "$CONFIG" != "singlenode-kerberized" && "$CONFIG" != "singlenode-ldap" ]]; then - echo "Usage: launcher-wrapper " +PRESTO_CONFIG_DIRECTORY="/docker/volumes/conf/presto/etc" +CONFIG_PROPERTIES_LOCATION="${PRESTO_CONFIG_DIRECTORY}/${CONFIG}.properties" + +if [[ ! -e ${CONFIG_PROPERTIES_LOCATION} ]]; then + echo "${CONFIG_PROPERTIES_LOCATION} does not exist" exit 1 fi -PRESTO_CONFIG_DIRECTORY="/docker/volumes/conf/presto/etc" - shift 1 /docker/volumes/presto-server/bin/launcher \ diff --git a/presto-product-tests/conf/docker/multinode-tls/compose.sh b/presto-product-tests/conf/docker/multinode-tls/compose.sh new file mode 100755 index 0000000000000..512e1c57d75a3 --- /dev/null +++ b/presto-product-tests/conf/docker/multinode-tls/compose.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +SCRIPT_DIRECTORY=${BASH_SOURCE%/*} + +source ${SCRIPT_DIRECTORY}/../common/compose-commons.sh + +docker-compose \ +-f ${SCRIPT_DIRECTORY}/../common/standard.yml \ +-f ${SCRIPT_DIRECTORY}/../common/jdbc_db.yml \ +-f ${BASH_SOURCE%/*}/../common/cassandra.yml \ +-f ${SCRIPT_DIRECTORY}/docker-compose.yml \ +"$@" diff --git a/presto-product-tests/conf/docker/multinode-tls/docker-compose.yml b/presto-product-tests/conf/docker/multinode-tls/docker-compose.yml new file mode 100644 index 0000000000000..21c2ceb457310 --- /dev/null +++ b/presto-product-tests/conf/docker/multinode-tls/docker-compose.yml @@ -0,0 +1,49 @@ +version: '2' +services: + + presto-master: + domainname: docker.cluster + hostname: presto-master + command: /docker/volumes/conf/docker/files/presto-launcher-wrapper.sh multinode-tls-master run + ports: + - '7778:7778' + networks: + default: + aliases: + - presto-master.docker.cluster + + presto-worker-1: + domainname: docker.cluster + hostname: presto-worker-1 + extends: + file: ../common/standard.yml + service: java-8-base + command: /docker/volumes/conf/docker/files/presto-launcher-wrapper.sh multinode-tls-worker run + networks: + default: + aliases: + - presto-worker-1.docker.cluster + depends_on: + - presto-master + volumes_from: + - presto-master + + presto-worker-2: + domainname: docker.cluster + hostname: presto-worker-2 + extends: + file: ../common/standard.yml + service: java-8-base + command: /docker/volumes/conf/docker/files/presto-launcher-wrapper.sh multinode-tls-worker run + networks: + default: + aliases: + - presto-worker-2.docker.cluster + depends_on: + - presto-master + volumes_from: + - presto-master + + application-runner: + volumes: + - ../../../conf/tempto/tempto-configuration-for-docker-tls.yaml:/docker/volumes/tempto/tempto-configuration-local.yaml diff --git a/presto-product-tests/conf/presto/etc/catalog/cassandra.properties b/presto-product-tests/conf/presto/etc/catalog/cassandra.properties index 65a89f4785c25..3c51fbcb5b1f7 100644 --- a/presto-product-tests/conf/presto/etc/catalog/cassandra.properties +++ b/presto-product-tests/conf/presto/etc/catalog/cassandra.properties @@ -1,3 +1,2 @@ connector.name=cassandra cassandra.contact-points=cassandra -cassandra.schema-cache-ttl=0s diff --git a/presto-product-tests/conf/presto/etc/catalog/mysql.properties b/presto-product-tests/conf/presto/etc/catalog/mysql.properties index db7463f53f77b..bb6b444c91d09 100644 --- a/presto-product-tests/conf/presto/etc/catalog/mysql.properties +++ b/presto-product-tests/conf/presto/etc/catalog/mysql.properties @@ -1,4 +1,4 @@ connector.name=mysql -connection-url=jdbc:mysql://mysql:13306/test +connection-url=jdbc:mysql://mysql:13306 connection-user=root connection-password=swarm diff --git a/presto-product-tests/conf/presto/etc/multinode-tls-master.properties b/presto-product-tests/conf/presto/etc/multinode-tls-master.properties new file mode 100644 index 0000000000000..47a45a7e9ad3b --- /dev/null +++ b/presto-product-tests/conf/presto/etc/multinode-tls-master.properties @@ -0,0 +1,29 @@ +# +# WARNING +# ^^^^^^^ +# This configuration file is for development only and should NOT be used be +# used in production. For example configuration, see the Presto documentation. +# + +node.id=will-be-overwritten +node.environment=test +node.internal-address-source=FQDN + +coordinator=true +node-scheduler.include-coordinator=true +discovery-server.enabled=true +discovery.uri=https://presto-master.docker.cluster:7778 + +query.max-memory=1GB +query.max-memory-per-node=512MB + +http-server.http.enabled=false +http-server.http.port=8080 +http-server.https.enabled=true +http-server.https.port=7778 +http-server.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +http-server.https.keystore.key=123456 + +internal-communication.https.required=true +internal-communication.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +internal-communication.https.keystore.key=123456 diff --git a/presto-product-tests/conf/presto/etc/multinode-tls-worker.properties b/presto-product-tests/conf/presto/etc/multinode-tls-worker.properties new file mode 100644 index 0000000000000..dbe128f62bd9a --- /dev/null +++ b/presto-product-tests/conf/presto/etc/multinode-tls-worker.properties @@ -0,0 +1,28 @@ +# +# WARNING +# ^^^^^^^ +# This configuration file is for development only and should NOT be used be +# used in production. For example configuration, see the Presto documentation. +# + +node.id=will-be-overwritten +node.environment=test +node.internal-address-source=FQDN + +coordinator=false +discovery-server.enabled=false +discovery.uri=https://presto-master.docker.cluster:7778 + +query.max-memory=1GB +query.max-memory-per-node=512MB + +http-server.http.enabled=false +http-server.http.port=8080 +http-server.https.enabled=true +http-server.https.port=7778 +http-server.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +http-server.https.keystore.key=123456 + +internal-communication.https.required=true +internal-communication.https.keystore.path=/docker/volumes/conf/presto/etc/docker.cluster.jks +internal-communication.https.keystore.key=123456 diff --git a/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml index b5a8d0c825a2a..8abcf5cb2c775 100644 --- a/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml +++ b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-kerberos.yaml @@ -22,8 +22,6 @@ databases: host: presto-master.docker.cluster port: 7778 server_address: https://${databases.presto.host}:${databases.presto.port} - # Use the HTTP interface in JDBC, as Kerberos authentication is not yet supported in there. - jdbc_url: jdbc:presto://${databases.presto.host}:8080/hive/${databases.hive.schema} # jdbc_user in here should satisfy two requirements in order to pass SQL standard access control checks in Presto: # 1) It should belong to the "admin" role in hive @@ -39,3 +37,13 @@ databases: cli_kerberos_service_name: presto-server cli_kerberos_use_canonical_hostname: false configured_hdfs_user: hdfs + + jdbc_url: "jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema}?\ + SSL=true&\ + SSLTrustStorePath=${databases.presto.https_keystore_path}&\ + SSLTrustStorePassword=${databases.presto.https_keystore_password}&\ + KerberosRemoteServiceName=${databases.presto.cli_kerberos_service_name}&\ + KerberosPrincipal=${databases.presto.cli_kerberos_principal}&\ + KerberosUseCanonicalHostname=${databases.presto.cli_kerberos_use_canonical_hostname}&\ + KerberosConfigPath=${databases.presto.cli_kerberos_config_path}&\ + KerberosKeytabPath=${databases.presto.cli_kerberos_keytab}" diff --git a/presto-product-tests/conf/tempto/tempto-configuration-for-docker-tls.yaml b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-tls.yaml new file mode 100644 index 0000000000000..a86a12f00efa6 --- /dev/null +++ b/presto-product-tests/conf/tempto/tempto-configuration-for-docker-tls.yaml @@ -0,0 +1,16 @@ +databases: + hive: + host: hadoop-master + presto: + host: presto-master.docker.cluster + port: 7778 + http_port: 8080 + https_port: ${databases.presto.port} + server_address: https://${databases.presto.host}:${databases.presto.port} + jdbc_url: "jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema}?\ + SSL=true&\ + SSLTrustStorePath=${databases.presto.https_keystore_path}&\ + SSLTrustStorePassword=${databases.presto.https_keystore_password}" + configured_hdfs_user: hive + https_keystore_path: /docker/volumes/conf/presto/etc/docker.cluster.jks + https_keystore_password: '123456' diff --git a/presto-product-tests/pom.xml b/presto-product-tests/pom.xml index 71352145b60c7..55ef1c8cbc14e 100644 --- a/presto-product-tests/pom.xml +++ b/presto-product-tests/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-product-tests diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java index 8b2bad68e4031..bfa83472587fb 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/TestGroups.java @@ -28,6 +28,8 @@ public final class TestGroups public static final String BLACKHOLE_CONNECTOR = "blackhole"; public static final String SMOKE = "smoke"; public static final String JDBC = "jdbc"; + public static final String MYSQL = "mysql"; + public static final String PRESTO_JDBC = "presto_jdbc"; public static final String SIMBA_JDBC = "simba_jdbc"; public static final String QUERY_ENGINE = "qe"; public static final String COMPARISON = "comparison"; @@ -55,6 +57,8 @@ public final class TestGroups public static final String SQL_SERVER = "sqlserver"; public static final String LDAP = "ldap"; public static final String LDAP_CLI = "ldap_cli"; + public static final String SKIP_ON_CDH = "skip_on_cdh"; + public static final String TLS = "tls"; private TestGroups() {} } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/TlsTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/TlsTests.java new file mode 100644 index 0000000000000..09e9eecfe13b3 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/TlsTests.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.google.common.base.Throwables; +import com.google.inject.Inject; +import com.google.inject.name.Named; +import com.teradata.tempto.query.QueryResult; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketTimeoutException; +import java.net.URI; +import java.util.List; + +import static com.facebook.presto.tests.TestGroups.PROFILE_SPECIFIC_TESTS; +import static com.facebook.presto.tests.TestGroups.TLS; +import static com.facebook.presto.tests.utils.QueryExecutors.onPresto; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.stream.Collectors.toList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public class TlsTests +{ + @Inject(optional = true) + @Named("databases.presto.http_port") + private Integer httpPort; + + @Inject(optional = true) + @Named("databases.presto.https_port") + private Integer httpsPort; + + @Test(groups = {TLS, PROFILE_SPECIFIC_TESTS}) + public void testHttpPortIsClosed() + throws Exception + { + assertThat(httpPort).isNotNull(); + assertThat(httpsPort).isNotNull(); + + waitForNodeRefresh(); + List activeNodesUrls = getActiveNodesUrls(); + assertThat(activeNodesUrls).hasSize(3); + + List hosts = activeNodesUrls.stream() + .map((uri) -> URI.create(uri).getHost()) + .collect(toList()); + + for (String host : hosts) { + assertPortIsOpen(host, httpsPort); + assertPortIsClosed(host, httpPort); + } + } + + private void waitForNodeRefresh() + throws InterruptedException + { + long deadline = System.currentTimeMillis() + MINUTES.toMillis(1); + while (System.currentTimeMillis() < deadline) { + if (getActiveNodesUrls().size() == 3) { + return; + } + Thread.sleep(100); + } + fail("Worker nodes haven't been discovered in 1 minutes."); + } + + private List getActiveNodesUrls() + { + QueryResult queryResult = onPresto() + .executeQuery("SELECT http_uri FROM system.runtime.nodes"); + return queryResult.rows() + .stream() + .map((row) -> row.get(0).toString()) + .collect(toList()); + } + + private static void assertPortIsClosed(String host, Integer port) + { + if (isPortOpen(host, port)) { + fail(format("Port %d at %s is expected to be closed", port, host)); + } + } + + private static void assertPortIsOpen(String host, Integer port) + { + if (!isPortOpen(host, port)) { + fail(format("Port %d at %s is expected to be open", port, host)); + } + } + + private static boolean isPortOpen(String host, Integer port) + { + try (Socket socket = new Socket()) { + socket.connect(new InetSocketAddress(InetAddress.getByName(host), port), 1000); + return true; + } + catch (ConnectException | SocketTimeoutException e) { + return false; + } + catch (IOException e) { + throw Throwables.propagate(e); + } + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/cassandra/TestInsertIntoCassandraTable.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/cassandra/TestInsertIntoCassandraTable.java new file mode 100644 index 0000000000000..d32890adfb52c --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/cassandra/TestInsertIntoCassandraTable.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.cassandra; + +import com.teradata.tempto.ProductTest; +import com.teradata.tempto.Requirement; +import com.teradata.tempto.RequirementsProvider; +import com.teradata.tempto.configuration.Configuration; +import com.teradata.tempto.internal.fulfillment.table.TableName; +import com.teradata.tempto.query.QueryResult; +import io.airlift.units.Duration; +import org.testng.annotations.Test; + +import static com.facebook.presto.tests.TestGroups.CASSANDRA; +import static com.facebook.presto.tests.cassandra.DataTypesTableDefinition.CASSANDRA_ALL_TYPES; +import static com.facebook.presto.tests.cassandra.TestConstants.CONNECTOR_NAME; +import static com.facebook.presto.tests.cassandra.TestConstants.KEY_SPACE; +import static com.facebook.presto.tests.utils.QueryAssertions.assertContainsEventually; +import static com.teradata.tempto.assertions.QueryAssert.Row.row; +import static com.teradata.tempto.assertions.QueryAssert.assertThat; +import static com.teradata.tempto.fulfillment.table.MutableTableRequirement.State.CREATED; +import static com.teradata.tempto.fulfillment.table.MutableTablesState.mutableTablesState; +import static com.teradata.tempto.fulfillment.table.TableRequirements.mutableTable; +import static com.teradata.tempto.query.QueryExecutor.query; +import static com.teradata.tempto.util.DateTimeUtils.parseTimestampInUTC; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MINUTES; + +public class TestInsertIntoCassandraTable + extends ProductTest + implements RequirementsProvider +{ + private static final String CASSANDRA_INSERT_TABLE = "Insert_All_Types"; + + @Override + public Requirement getRequirements(Configuration configuration) + { + return mutableTable(CASSANDRA_ALL_TYPES, CASSANDRA_INSERT_TABLE, CREATED); + } + + @Test(groups = CASSANDRA) + public void testInsertIntoValuesToCassandraTableAllSimpleTypes() + throws Exception + { + TableName table = mutableTablesState().get(CASSANDRA_INSERT_TABLE).getTableName(); + String tableNameInDatabase = String.format("%s.%s", CONNECTOR_NAME, table.getNameInDatabase()); + + assertContainsEventually(() -> query(format("SHOW TABLES FROM %s.%s", CONNECTOR_NAME, KEY_SPACE)), + query(format("SELECT '%s'", table.getSchemalessNameInDatabase())), + new Duration(1, MINUTES)); + + QueryResult queryResult = query("SELECT * FROM " + tableNameInDatabase); + assertThat(queryResult).hasNoRows(); + + // TODO Following types are not supported now. We need to change null into the value after fixing it + // blob, frozen>, inet, list, map, set, timeuuid, decimal, uuid, varint + query("INSERT INTO " + tableNameInDatabase + + "(a, b, bl, bo, d, do, f, fr, i, integer, l, m, s, t, ti, tu, u, v, vari) VALUES (" + + "'ascii value', " + + "BIGINT '99999', " + + "null, " + + "true, " + + "null, " + + "123.456789, " + + "REAL '123.45678', " + + "null, " + + "null, " + + "123, " + + "null, " + + "null, " + + "null, " + + "'text value', " + + "timestamp '9999-12-31 23:59:59'," + + "null, " + + "null, " + + "'varchar value'," + + "null)"); + + assertThat(query("SELECT * FROM " + tableNameInDatabase)).containsOnly( + row( + "ascii value", + 99999, + null, + true, + null, + 123.456789, + 123.45678, + null, + null, + 123, + null, + null, + null, + "text value", + parseTimestampInUTC("9999-12-31 23:59:59"), + null, + null, + "varchar value", + null)); + + // insert null for all datatypes + query("INSERT INTO " + tableNameInDatabase + + "(a, b, bl, bo, d, do, f, fr, i, integer, l, m, s, t, ti, tu, u, v, vari) VALUES (" + + "'key 1', null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) "); + assertThat(query(format("SELECT * FROM %s WHERE a = 'key 1'", tableNameInDatabase))).containsOnly( + row("key 1", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)); + + // insert into only a subset of columns + query(format("INSERT INTO %s (a, bo, integer, t) VALUES ('key 2', false, 999, 'text 2')", tableNameInDatabase)); + assertThat(query(format("SELECT * FROM %s WHERE a = 'key 2'", tableNameInDatabase))).containsOnly( + row("key 2", null, null, false, null, null, null, null, null, 999, null, null, null, "text 2", null, null, null, null, null)); + + // negative test: failed to insert null to primary key + assertThat(() -> query(format("INSERT INTO %s (a) VALUES (null) ", tableNameInDatabase))) + .failsWithMessage("Invalid null value in condition for column a"); + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java index 1ef7f8cf3ebc6..fe2c8694e7a31 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/cli/PrestoLdapCliTests.java @@ -202,7 +202,7 @@ public void shouldFailQueryForLdapWithoutPassword() "--truststore-password", ldapTruststorePassword, "--user", ldapUserName, "--execute", "select * from hive.default.nation;"); - assertTrue(trimLines(presto.readRemainingErrorLines()).stream().anyMatch(str -> str.contains("statusMessage=Unauthorized"))); + assertTrue(trimLines(presto.readRemainingErrorLines()).stream().anyMatch(str -> str.contains("Authentication failed: Unauthorized"))); } @Test(groups = {LDAP, LDAP_CLI, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java index 72e99a9489224..05b889d859dc1 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java @@ -91,7 +91,6 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui " smallint_to_bigint SMALLINT," + " int_to_bigint INT," + " bigint_to_varchar BIGINT," + - " varchar_to_integer STRING," + " float_to_double FLOAT" + ") " + "PARTITIONED BY (id BIGINT) " + @@ -111,7 +110,6 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder parquetTableDefini " smallint_to_bigint SMALLINT," + " int_to_bigint INT," + " bigint_to_varchar BIGINT," + - " varchar_to_integer STRING," + " float_to_double DOUBLE" + ") " + "PARTITIONED BY (id BIGINT) " + @@ -216,8 +214,8 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition) executeHiveQuery(format("INSERT INTO TABLE %s " + "PARTITION (id=1) " + "VALUES" + - "(-1, 2, -3, 100, -101, 2323, 12345, '-1025', 0.5)," + - "(1, -2, null, -100, 101, -2323, -12345, '99999999999999999999999999999', -1.5)", + "(-1, 2, -3, 100, -101, 2323, 12345, 0.5)," + + "(1, -2, null, -100, 101, -2323, -12345, -1.5)", tableName)); alterTableColumnTypes(tableName); @@ -234,7 +232,6 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition) -101L, 2323L, "12345", - -1025, 0.5, 1), row( @@ -245,7 +242,6 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition) 101L, -2323L, "-12345", - null, -1.5, 1)); } @@ -260,7 +256,6 @@ private void assertProperAlteredTableSchema(String tableName) row("smallint_to_bigint", "bigint"), row("int_to_bigint", "bigint"), row("bigint_to_varchar", "varchar"), - row("varchar_to_integer", "integer"), row("float_to_double", "double"), row("id", "bigint") ); @@ -278,7 +273,6 @@ private void assertColumnTypes(QueryResult queryResult) BIGINT, BIGINT, LONGNVARCHAR, - INTEGER, DOUBLE, BIGINT ); @@ -292,7 +286,6 @@ else if (usingTeradataJdbcDriver(connection)) { BIGINT, BIGINT, VARBINARY, - INTEGER, DOUBLE, BIGINT ); @@ -311,7 +304,6 @@ private static void alterTableColumnTypes(String tableName) executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN smallint_to_bigint smallint_to_bigint bigint", tableName)); executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN int_to_bigint int_to_bigint bigint", tableName)); executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN bigint_to_varchar bigint_to_varchar string", tableName)); - executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN varchar_to_integer varchar_to_integer int", tableName)); executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN float_to_double float_to_double double", tableName)); } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java new file mode 100644 index 0000000000000..25c076b886e7b --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java @@ -0,0 +1,388 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.hive; + +import com.teradata.tempto.ProductTest; +import com.teradata.tempto.Requirement; +import com.teradata.tempto.Requirements; +import com.teradata.tempto.RequirementsProvider; +import com.teradata.tempto.Requires; +import com.teradata.tempto.configuration.Configuration; +import com.teradata.tempto.fulfillment.table.MutableTableRequirement; +import com.teradata.tempto.query.QueryExecutor; +import org.testng.annotations.Test; + +import static com.facebook.presto.tests.TestGroups.HIVE_CONNECTOR; +import static com.facebook.presto.tests.TestGroups.SKIP_ON_CDH; +import static com.facebook.presto.tests.hive.AllSimpleTypesTableDefinitions.ALL_HIVE_SIMPLE_TYPES_TEXTFILE; +import static com.facebook.presto.tests.hive.HiveTableDefinitions.NATION_PARTITIONED_BY_REGIONKEY; +import static com.teradata.tempto.assertions.QueryAssert.Row.row; +import static com.teradata.tempto.assertions.QueryAssert.anyOf; +import static com.teradata.tempto.assertions.QueryAssert.assertThat; +import static com.teradata.tempto.context.ThreadLocalTestContextHolder.testContext; +import static com.teradata.tempto.fulfillment.table.MutableTablesState.mutableTablesState; +import static com.teradata.tempto.fulfillment.table.TableRequirements.mutableTable; +import static com.teradata.tempto.fulfillment.table.hive.tpch.TpchTableDefinitions.NATION; +import static com.teradata.tempto.query.QueryExecutor.query; + +public class TestHiveTableStatistics + extends ProductTest +{ + private static class UnpartitionedNationTable + implements RequirementsProvider + { + @Override + public Requirement getRequirements(Configuration configuration) + { + return mutableTable(NATION); + } + } + + private static class PartitionedNationTable + implements RequirementsProvider + { + @Override + public Requirement getRequirements(Configuration configuration) + { + return mutableTable(NATION_PARTITIONED_BY_REGIONKEY); + } + } + + private static final String ALL_TYPES_TABLE_NAME = "all_types"; + private static final String EMPTY_ALL_TYPES_TABLE_NAME = "empty_all_types"; + + private static final class AllTypesTable + implements RequirementsProvider + { + @Override + public Requirement getRequirements(Configuration configuration) + { + return Requirements.compose( + mutableTable(ALL_HIVE_SIMPLE_TYPES_TEXTFILE, ALL_TYPES_TABLE_NAME, MutableTableRequirement.State.LOADED), + mutableTable(ALL_HIVE_SIMPLE_TYPES_TEXTFILE, EMPTY_ALL_TYPES_TABLE_NAME, MutableTableRequirement.State.CREATED)); + } + } + + @Test(groups = {HIVE_CONNECTOR}) + @Requires(UnpartitionedNationTable.class) + public void testStatisticsForUnpartitionedTable() + { + String tableNameInDatabase = mutableTablesState().get(NATION.getName()).getNameInDatabase(); + + String showStatsWholeTable = "SHOW STATS FOR " + tableNameInDatabase; + + // table not analyzed + + assertThat(query(showStatsWholeTable)).containsOnly( + row("n_nationkey", null, null, null, null), + row("n_name", null, null, null, null), + row("n_regionkey", null, null, null, null), + row("n_comment", null, null, null, null), + row(null, null, null, null, anyOf(null, 0.0))); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) + + // basic analysis + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); + + assertThat(query(showStatsWholeTable)).containsOnly( + row("n_nationkey", null, null, null, null), + row("n_name", null, null, null, null), + row("n_regionkey", null, null, null, null), + row("n_comment", null, null, null, null), + row(null, null, null, null, 25.0)); + + // column analysis + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); + + assertThat(query(showStatsWholeTable)).containsOnly( + row("n_nationkey", null, 19.0, 0.0, null), + row("n_name", null, 24.0, 0.0, null), + row("n_regionkey", null, 5.0, 0.0, null), + row("n_comment", null, 31.0, 0.0, null), + row(null, null, null, null, 25.0)); + } + + @Test(groups = {HIVE_CONNECTOR}) + @Requires(PartitionedNationTable.class) + public void testStatisticsForPartitionedTable() + { + String tableNameInDatabase = mutableTablesState().get(NATION_PARTITIONED_BY_REGIONKEY.getName()).getNameInDatabase(); + + String showStatsWholeTable = "SHOW STATS FOR " + tableNameInDatabase; + String showStatsPartitionOne = "SHOW STATS FOR (SELECT * FROM " + tableNameInDatabase + " WHERE p_regionkey = 1)"; + String showStatsPartitionTwo = "SHOW STATS FOR (SELECT * FROM " + tableNameInDatabase + " WHERE p_regionkey = 2)"; + + // table not analyzed + + assertThat(query(showStatsWholeTable)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 3.0, null, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, null)); + + assertThat(query(showStatsPartitionOne)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 1.0, null, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, null)); + + // basic analysis for single partition + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"1\") COMPUTE STATISTICS"); + + assertThat(query(showStatsWholeTable)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, 15.0)); + + assertThat(query(showStatsPartitionOne)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, 5.0)); + + assertThat(query(showStatsPartitionTwo)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 1.0, null, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, null)); + + // basic analysis for all partitions + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS"); + + assertThat(query(showStatsWholeTable)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, 15.0)); + + assertThat(query(showStatsPartitionOne)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, 5.0)); + + assertThat(query(showStatsPartitionTwo)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, 5.0)); + + // column analysis for single partition + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"1\") COMPUTE STATISTICS FOR COLUMNS"); + + assertThat(query(showStatsWholeTable)).containsOnly( + row("p_nationkey", null, 5.0, 0.0, null), + row("p_name", null, 6.0, 0.0, null), + row("p_regionkey", null, 3.0, 0.0, null), + row("p_comment", null, 1.0, 0.0, null), + row(null, null, null, null, 15.0)); + + assertThat(query(showStatsPartitionOne)).containsOnly( + row("p_nationkey", null, 5.0, 0.0, null), + row("p_name", null, 6.0, 0.0, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, 1.0, 0.0, null), + row(null, null, null, null, 5.0)); + + assertThat(query(showStatsPartitionTwo)).containsOnly( + row("p_nationkey", null, null, null, null), + row("p_name", null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, null, null, null), + row(null, null, null, null, 5.0)); + + // column analysis for all partitions + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS FOR COLUMNS"); + + assertThat(query(showStatsWholeTable)).containsOnly( + row("p_nationkey", null, 5.0, 0.0, null), + row("p_name", null, 6.0, 0.0, null), + row("p_regionkey", null, 3.0, 0.0, null), + row("p_comment", null, 1.0, 0.0, null), + row(null, null, null, null, 15.0)); + + assertThat(query(showStatsPartitionOne)).containsOnly( + row("p_nationkey", null, 5.0, 0.0, null), + row("p_name", null, 6.0, 0.0, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, 1.0, 0.0, null), + row(null, null, null, null, 5.0)); + + assertThat(query(showStatsPartitionTwo)).containsOnly( + row("p_nationkey", null, 4.0, 0.0, null), + row("p_name", null, 6.0, 0.0, null), + row("p_regionkey", null, 1.0, 0.0, null), + row("p_comment", null, 1.0, 0.0, null), + row(null, null, null, null, 5.0)); + } + + @Test(groups = {HIVE_CONNECTOR, SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats + @Requires(AllTypesTable.class) + public void testStatisticsForAllDataTypes() + { + String tableNameInDatabase = mutableTablesState().get(ALL_TYPES_TABLE_NAME).getNameInDatabase(); + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); + + assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( + row("c_tinyint", null, null, null, null), + row("c_smallint", null, null, null, null), + row("c_int", null, null, null, null), + row("c_bigint", null, null, null, null), + row("c_float", null, null, null, null), + row("c_double", null, null, null, null), + row("c_decimal", null, null, null, null), + row("c_decimal_w_params", null, null, null, null), + row("c_timestamp", null, null, null, null), + row("c_date", null, null, null, null), + row("c_string", null, null, null, null), + row("c_varchar", null, null, null, null), + row("c_char", null, null, null, null), + row("c_boolean", null, null, null, null), + row("c_binary", null, null, null, null), + row(null, null, null, null, 1.0)); + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); + + assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( + row("c_tinyint", null, 1.0, 0.0, null), + row("c_smallint", null, 1.0, 0.0, null), + row("c_int", null, 2.0, 0.0, null), + row("c_bigint", null, 1.0, 0.0, null), + row("c_float", null, 1.0, 0.0, null), + row("c_double", null, 1.0, 0.0, null), + row("c_decimal", null, 1.0, 0.0, null), + row("c_decimal_w_params", null, 1.0, 0.0, null), + row("c_timestamp", null, 1.0, 0.0, null), + row("c_date", null, 2.0, 0.0, null), + row("c_string", null, 1.0, 0.0, null), + row("c_varchar", null, 1.0, 0.0, null), + row("c_char", null, 1.0, 0.0, null), + row("c_boolean", null, 1.0, 0.0, null), + row("c_binary", null, null, 0.0, null), + row(null, null, null, null, 1.0)); + } + + @Test(groups = {HIVE_CONNECTOR, SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats + @Requires(AllTypesTable.class) + public void testStatisticsForAllDataTypesNoData() + { + String tableNameInDatabase = mutableTablesState().get(EMPTY_ALL_TYPES_TABLE_NAME).getNameInDatabase(); + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); + + assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( + row("c_tinyint", null, null, null, null), + row("c_smallint", null, null, null, null), + row("c_int", null, null, null, null), + row("c_bigint", null, null, null, null), + row("c_float", null, null, null, null), + row("c_double", null, null, null, null), + row("c_decimal", null, null, null, null), + row("c_decimal_w_params", null, null, null, null), + row("c_timestamp", null, null, null, null), + row("c_date", null, null, null, null), + row("c_string", null, null, null, null), + row("c_varchar", null, null, null, null), + row("c_char", null, null, null, null), + row("c_boolean", null, null, null, null), + row("c_binary", null, null, null, null), + row(null, null, null, null, 0.0)); + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); + + assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( + row("c_tinyint", null, 0.0, 0.0, null), + row("c_smallint", null, 0.0, 0.0, null), + row("c_int", null, 0.0, 0.0, null), + row("c_bigint", null, 0.0, 0.0, null), + row("c_float", null, 0.0, 0.0, null), + row("c_double", null, 0.0, 0.0, null), + row("c_decimal", null, 0.0, 0.0, null), + row("c_decimal_w_params", null, 0.0, 0.0, null), + row("c_timestamp", null, 0.0, 0.0, null), + row("c_date", null, 0.0, 0.0, null), + row("c_string", null, 0.0, 0.0, null), + row("c_varchar", null, 0.0, 0.0, null), + row("c_char", null, 0.0, 0.0, null), + row("c_boolean", null, 0.0, 0.0, null), + row("c_binary", null, null, 0.0, null), + row(null, null, null, null, 0.0)); + } + + @Test(groups = {HIVE_CONNECTOR, SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats + @Requires(AllTypesTable.class) + public void testStatisticsForAllDataTypesOnlyNulls() + { + String tableNameInDatabase = mutableTablesState().get(EMPTY_ALL_TYPES_TABLE_NAME).getNameInDatabase(); + onHive().executeQuery("INSERT INTO TABLE " + tableNameInDatabase + " VALUES(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)"); + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); + + assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( + row("c_tinyint", null, null, null, null), + row("c_smallint", null, null, null, null), + row("c_int", null, null, null, null), + row("c_bigint", null, null, null, null), + row("c_float", null, null, null, null), + row("c_double", null, null, null, null), + row("c_decimal", null, null, null, null), + row("c_decimal_w_params", null, null, null, null), + row("c_timestamp", null, null, null, null), + row("c_date", null, null, null, null), + row("c_string", null, null, null, null), + row("c_varchar", null, null, null, null), + row("c_char", null, null, null, null), + row("c_boolean", null, null, null, null), + row("c_binary", null, null, null, null), + row(null, null, null, null, 1.0)); + + onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); + + assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( + row("c_tinyint", null, 1.0, 1.0, null), + row("c_smallint", null, 1.0, 1.0, null), + row("c_int", null, 1.0, 1.0, null), + row("c_bigint", null, 1.0, 1.0, null), + row("c_float", null, 1.0, 1.0, null), + row("c_double", null, 1.0, 1.0, null), + row("c_decimal", null, 1.0, 1.0, null), + row("c_decimal_w_params", null, 1.0, 1.0, null), + row("c_timestamp", null, 1.0, 1.0, null), + row("c_date", null, 1.0, 1.0, null), + row("c_string", null, 1.0, 1.0, null), + row("c_varchar", null, 1.0, 1.0, null), + row("c_char", null, 1.0, 1.0, null), + row("c_boolean", null, 0.0, 1.0, null), + row("c_binary", null, null, 1.0, null), + row(null, null, null, null, 1.0)); + } + + private static QueryExecutor onHive() + { + return testContext().getDependency(QueryExecutor.class, "hive"); + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapJdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapJdbcTests.java new file mode 100644 index 0000000000000..d5a1a3bf8f209 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapJdbcTests.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.jdbc; + +import com.google.inject.Inject; +import com.google.inject.name.Named; +import com.teradata.tempto.ProductTest; +import com.teradata.tempto.Requirement; +import com.teradata.tempto.RequirementsProvider; +import com.teradata.tempto.configuration.Configuration; +import com.teradata.tempto.fulfillment.ldap.LdapObjectRequirement; +import com.teradata.tempto.query.QueryResult; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; + +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.AMERICA_ORG; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ASIA_ORG; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ORPHAN_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP_USER; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +public abstract class LdapJdbcTests + extends ProductTest + implements RequirementsProvider +{ + protected static final long TIMEOUT = 30 * 1000; // seconds per test + + protected static final String NATION_SELECT_ALL_QUERY = "select * from tpch.tiny.nation"; + + @Inject + @Named("databases.presto.cli_ldap_truststore_path") + protected String ldapTruststorePath; + + @Inject + @Named("databases.presto.cli_ldap_truststore_password") + protected String ldapTruststorePassword; + + @Inject + @Named("databases.presto.cli_ldap_user_name") + protected String ldapUserName; + + @Inject + @Named("databases.presto.cli_ldap_user_password") + protected String ldapUserPassword; + + @Inject + @Named("databases.presto.cli_ldap_server_address") + private String prestoServer; + + @Override + public Requirement getRequirements(Configuration configuration) + { + return new LdapObjectRequirement( + Arrays.asList( + AMERICA_ORG, ASIA_ORG, + DEFAULT_GROUP, PARENT_GROUP, CHILD_GROUP, + DEFAULT_GROUP_USER, PARENT_GROUP_USER, CHILD_GROUP_USER, ORPHAN_USER + )); + } + + protected void expectQueryToFail(String user, String password, String message) + { + try { + executeLdapQuery(NATION_SELECT_ALL_QUERY, user, password); + fail(); + } + catch (SQLException exception) { + assertEquals(exception.getMessage(), message); + } + } + + protected QueryResult executeLdapQuery(String query, String name, String password) + throws SQLException + { + try (Connection connection = getLdapConnection(name, password)) { + Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(query); + return QueryResult.forResultSet(rs); + } + } + + private Connection getLdapConnection(String name, String password) + throws SQLException + { + return DriverManager.getConnection(getLdapUrl(), name, password); + } + + protected String prestoServer() + { + String prefix = "https://"; + checkState(prestoServer.startsWith(prefix), "invalid server address: %s", prestoServer); + return prestoServer.substring(prefix.length()); + } + + protected String getLdapUrl() + { + return format(getLdapUrlFormat(), prestoServer(), ldapTruststorePath, ldapTruststorePassword); + } + + protected abstract String getLdapUrlFormat(); +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapPrestoJdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapPrestoJdbcTests.java new file mode 100644 index 0000000000000..6b9664c70f4c1 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapPrestoJdbcTests.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.jdbc; + +import com.teradata.tempto.Requires; +import com.teradata.tempto.fulfillment.table.hive.tpch.ImmutableTpchTablesRequirements.ImmutableNationTable; +import org.testng.annotations.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ORPHAN_USER; +import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP_USER; +import static com.facebook.presto.tests.TestGroups.LDAP; +import static com.facebook.presto.tests.TestGroups.PRESTO_JDBC; +import static com.facebook.presto.tests.TestGroups.PROFILE_SPECIFIC_TESTS; +import static com.facebook.presto.tests.TpchTableResults.PRESTO_NATION_RESULT; +import static com.teradata.tempto.assertions.QueryAssert.assertThat; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.fail; + +public class LdapPrestoJdbcTests + extends LdapJdbcTests +{ + @Override + protected String getLdapUrlFormat() + { + return "jdbc:presto://%s?SSL=true&SSLTrustStorePath=%s&SSLTrustStorePassword=%s"; + } + + @Requires(ImmutableNationTable.class) + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldRunQueryWithLdap() + throws SQLException + { + assertThat(executeLdapQuery(NATION_SELECT_ALL_QUERY, ldapUserName, ldapUserPassword)).matches(PRESTO_NATION_RESULT); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapUserInChildGroup() + { + String name = CHILD_GROUP_USER.getAttributes().get("cn"); + expectQueryToFailForUserNotInGroup(name); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapUserInParentGroup() + { + String name = PARENT_GROUP_USER.getAttributes().get("cn"); + expectQueryToFailForUserNotInGroup(name); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForOrphanLdapUser() + { + String name = ORPHAN_USER.getAttributes().get("cn"); + expectQueryToFailForUserNotInGroup(name); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForWrongLdapPassword() + { + expectQueryToFail(ldapUserName, "wrong_password", "Authentication failed: Invalid credentials: [LDAP: error code 49 - Invalid Credentials]"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForWrongLdapUser() + { + expectQueryToFail("invalid_user", ldapUserPassword, "Authentication failed: Invalid credentials: [LDAP: error code 49 - Invalid Credentials]"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForEmptyUser() + { + expectQueryToFail("", ldapUserPassword, "Connection property 'user' value is empty"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapWithoutPassword() + { + expectQueryToFail(ldapUserName, null, "Authentication failed: Unauthorized"); + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailQueryForLdapWithoutSsl() + { + try { + DriverManager.getConnection("jdbc:presto://" + prestoServer(), ldapUserName, ldapUserPassword); + fail(); + } + catch (SQLException exception) { + assertEquals(exception.getMessage(), "Authentication using username/password requires SSL to be enabled"); + } + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailForIncorrectTrustStore() + { + try { + String url = format("jdbc:presto://%s?SSL=true&SSLTrustStorePath=%s&SSLTrustStorePassword=%s", prestoServer(), ldapTruststorePath, "wrong_password"); + Connection connection = DriverManager.getConnection(url, ldapUserName, ldapUserPassword); + Statement statement = connection.createStatement(); + statement.executeQuery(NATION_SELECT_ALL_QUERY); + fail(); + } + catch (SQLException exception) { + assertEquals(exception.getMessage(), "Error setting up SSL: Keystore was tampered with, or password was incorrect"); + } + } + + @Test(groups = {LDAP, PRESTO_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) + public void shouldFailForUserWithColon() + { + expectQueryToFail("UserWith:Colon", ldapUserPassword, "Illegal character ':' found in username"); + } + + private void expectQueryToFailForUserNotInGroup(String user) + { + expectQueryToFail(user, ldapUserPassword, format("Authentication failed: Unauthorized user: User %s not a member of the authorized group", user)); + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapSimbaJdbcTests.java similarity index 65% rename from presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapTests.java rename to presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapSimbaJdbcTests.java index ac9018e6cd2f4..29bdfd3bc9150 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/LdapSimbaJdbcTests.java @@ -13,35 +13,18 @@ */ package com.facebook.presto.tests.jdbc; -import com.google.inject.Inject; -import com.google.inject.name.Named; -import com.teradata.tempto.BeforeTestWithContext; -import com.teradata.tempto.ProductTest; -import com.teradata.tempto.Requirement; -import com.teradata.tempto.RequirementsProvider; import com.teradata.tempto.Requires; -import com.teradata.tempto.configuration.Configuration; -import com.teradata.tempto.fulfillment.ldap.LdapObjectRequirement; import com.teradata.tempto.fulfillment.table.hive.tpch.ImmutableTpchTablesRequirements.ImmutableNationTable; -import com.teradata.tempto.query.QueryResult; import org.testng.annotations.Test; import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; -import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; -import java.util.Arrays; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.AMERICA_ORG; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ASIA_ORG; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP; import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.CHILD_GROUP_USER; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.DEFAULT_GROUP_USER; import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.ORPHAN_USER; -import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP; import static com.facebook.presto.tests.ImmutableLdapObjectDefinitions.PARENT_GROUP_USER; import static com.facebook.presto.tests.TestGroups.LDAP; import static com.facebook.presto.tests.TestGroups.PROFILE_SPECIFIC_TESTS; @@ -51,14 +34,9 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; -public class LdapTests - extends ProductTest - implements RequirementsProvider - +public class LdapSimbaJdbcTests + extends LdapJdbcTests { - private static final long TIMEOUT = 300 * 1000; // 30 secs per test - - private static final String NATION_SELECT_ALL_QUERY = "select * from tpch.tiny.nation"; private static final String JDBC_URL_FORMAT = "jdbc:presto://%s;AuthenticationType=LDAP Authentication;" + "SSLTrustStorePath=%s;SSLTrustStorePwd=%s;AllowSelfSignedServerCert=1;AllowHostNameCNMismatch=1"; private static final String SSL_CERTIFICATE_ERROR = @@ -72,42 +50,10 @@ public class LdapTests private static final String INVALID_SSL_PROPERTY = "[Teradata][Presto](100200) Connection string is invalid: SSL value is not valid for given AuthenticationType."; - @Inject - @Named("databases.presto.cli_ldap_truststore_path") - private String ldapTruststorePath; - - @Inject - @Named("databases.presto.cli_ldap_truststore_password") - private String ldapTruststorePassword; - - @Inject - @Named("databases.presto.cli_ldap_user_name") - private String ldapUserName; - - @Inject - @Named("databases.presto.cli_ldap_user_password") - private String ldapUserPassword; - - @Inject - @Named("databases.presto.cli_ldap_server_address") - private String prestoServer; - - @BeforeTestWithContext - public void setup() - throws SQLException - { - prestoServer = prestoServer.substring(8); - } - @Override - public Requirement getRequirements(Configuration configuration) + protected String getLdapUrlFormat() { - return new LdapObjectRequirement( - Arrays.asList( - AMERICA_ORG, ASIA_ORG, - DEFAULT_GROUP, PARENT_GROUP, CHILD_GROUP, - DEFAULT_GROUP_USER, PARENT_GROUP_USER, CHILD_GROUP_USER, ORPHAN_USER - )); + return JDBC_URL_FORMAT; } @Requires(ImmutableNationTable.class) @@ -188,10 +134,10 @@ public void shouldFailForIncorrectTrustStore() throws IOException, InterruptedException { try { - String url = String.format(JDBC_URL_FORMAT, prestoServer, ldapTruststorePath, "wrong_password"); + String url = String.format(JDBC_URL_FORMAT, prestoServer(), ldapTruststorePath, "wrong_password"); Connection connection = DriverManager.getConnection(url, ldapUserName, ldapUserPassword); Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(NATION_SELECT_ALL_QUERY); + statement.executeQuery(NATION_SELECT_ALL_QUERY); fail(); } catch (SQLException exception) { @@ -201,7 +147,7 @@ public void shouldFailForIncorrectTrustStore() @Test(groups = {LDAP, SIMBA_JDBC, PROFILE_SPECIFIC_TESTS}, timeOut = TIMEOUT) public void shouldFailForUserWithColon() - throws SQLException, InterruptedException + throws SQLException, InterruptedException { expectQueryToFail("UserWith:Colon", ldapUserPassword, MALFORMED_CREDENTIALS_ERROR); } @@ -210,36 +156,4 @@ private void expectQueryToFailForUserNotInGroup(String user) { expectQueryToFail(user, ldapUserPassword, UNAUTHORIZED_USER_ERROR); } - - private void expectQueryToFail(String user, String password, String message) - { - try { - executeLdapQuery(NATION_SELECT_ALL_QUERY, user, password); - fail(); - } - catch (SQLException exception) { - assertEquals(exception.getMessage(), message); - } - } - - private QueryResult executeLdapQuery(String query, String name, String password) - throws SQLException - { - try (Connection connection = getLdapConnection(name, password)) { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery(query); - return QueryResult.forResultSet(rs); - } - } - - private Connection getLdapConnection(String name, String password) - throws SQLException - { - return DriverManager.getConnection(getLdapUrl(), name, password); - } - - private String getLdapUrl() - { - return String.format(JDBC_URL_FORMAT, prestoServer, ldapTruststorePath, ldapTruststorePassword); - } } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/mysql/CreateTableAsSelect.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/mysql/CreateTableAsSelect.java new file mode 100644 index 0000000000000..46b4e92601e50 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/mysql/CreateTableAsSelect.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.mysql; + +import com.teradata.tempto.AfterTestWithContext; +import com.teradata.tempto.BeforeTestWithContext; +import com.teradata.tempto.ProductTest; +import com.teradata.tempto.Requires; +import com.teradata.tempto.fulfillment.table.hive.tpch.ImmutableTpchTablesRequirements.ImmutableNationTable; +import com.teradata.tempto.query.QueryResult; +import io.airlift.log.Logger; +import org.testng.annotations.Test; + +import static com.facebook.presto.tests.TestGroups.JDBC; +import static com.facebook.presto.tests.TestGroups.MYSQL; +import static com.facebook.presto.tests.utils.QueryExecutors.onMySql; +import static com.teradata.tempto.assertions.QueryAssert.Row.row; +import static com.teradata.tempto.assertions.QueryAssert.assertThat; +import static com.teradata.tempto.query.QueryExecutor.query; +import static java.lang.String.format; + +public class CreateTableAsSelect + extends ProductTest +{ + private static final String TABLE_NAME = "test.nation_tmp"; + + @BeforeTestWithContext + @AfterTestWithContext + public void dropTestTable() + { + try { + onMySql().executeQuery(format("DROP TABLE IF EXISTS %s", TABLE_NAME)); + } + catch (Exception e) { + Logger.get(getClass()).warn(e, "failed to drop table"); + } + } + + @Requires(ImmutableNationTable.class) + @Test(groups = {JDBC, MYSQL}) + public void testCreateTableAsSelect() + { + QueryResult queryResult = query(format("CREATE TABLE mysql.%s AS SELECT * FROM nation", TABLE_NAME)); + assertThat(queryResult).containsOnly(row(25)); + } +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryAssertions.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryAssertions.java new file mode 100644 index 0000000000000..9b941dce35163 --- /dev/null +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryAssertions.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.utils; + +import com.google.common.base.Joiner; +import com.google.common.collect.Iterables; +import com.teradata.tempto.query.QueryResult; +import io.airlift.units.Duration; + +import java.util.function.Supplier; + +import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; +import static io.airlift.units.Duration.nanosSince; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.testng.Assert.fail; + +public class QueryAssertions +{ + public static void assertContainsEventually(Supplier all, QueryResult expectedSubset, Duration timeout) + { + long start = System.nanoTime(); + while (!Thread.currentThread().isInterrupted()) { + try { + assertContains(all.get(), expectedSubset); + return; + } + catch (AssertionError e) { + if (nanosSince(start).compareTo(timeout) > 0) { + throw e; + } + } + sleepUninterruptibly(50, MILLISECONDS); + } + } + + public static void assertContains(QueryResult all, QueryResult expectedSubset) + { + for (Object row : expectedSubset.rows()) { + if (!all.rows().contains(row)) { + fail(format("expected row missing: %s\nAll %s rows:\n %s\nExpected subset %s rows:\n %s\n", + row, + all.getRowsCount(), + Joiner.on("\n ").join(Iterables.limit(all.rows(), 100)), + expectedSubset.getRowsCount(), + Joiner.on("\n ").join(Iterables.limit(expectedSubset.rows(), 100)))); + } + } + } + + private QueryAssertions() {} +} diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryExecutors.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryExecutors.java index 93e0c642fa784..6a80ba0eafb4e 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryExecutors.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/utils/QueryExecutors.java @@ -39,5 +39,10 @@ public static QueryExecutor onSqlServer() return testContext().getDependency(QueryExecutor.class, "sqlserver"); } + public static QueryExecutor onMySql() + { + return testContext().getDependency(QueryExecutor.class, "mysql"); + } + private QueryExecutors() {} } diff --git a/presto-product-tests/src/main/resources/tempto-configuration.yaml b/presto-product-tests/src/main/resources/tempto-configuration.yaml index d6ea5880700f3..0eb0898971dd8 100644 --- a/presto-product-tests/src/main/resources/tempto-configuration.yaml +++ b/presto-product-tests/src/main/resources/tempto-configuration.yaml @@ -29,7 +29,7 @@ databases: jdbc_driver_class: com.facebook.presto.jdbc.PrestoDriver jdbc_url: jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/${databases.hive.schema} jdbc_user: hdfs - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false presto_tpcds: @@ -37,7 +37,7 @@ databases: jdbc_driver_class: com.facebook.presto.jdbc.PrestoDriver jdbc_url: jdbc:presto://${databases.presto.host}:${databases.presto.port}/hive/tpcds jdbc_user: hdfs - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false alice@presto: @@ -47,7 +47,7 @@ databases: jdbc_driver_class: ${databases.presto.jdbc_driver_class} jdbc_url: ${databases.presto.jdbc_url} jdbc_user: alice - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false https_keystore_path: ${databases.presto.https_keystore_path} https_keystore_password: ${databases.presto.https_keystore_password} @@ -59,7 +59,7 @@ databases: jdbc_driver_class: ${databases.presto.jdbc_driver_class} jdbc_url: ${databases.presto.jdbc_url} jdbc_user: bob - jdbc_password: na + jdbc_password: "***empty***" jdbc_pooling: false https_keystore_path: ${databases.presto.https_keystore_path} https_keystore_password: ${databases.presto.https_keystore_password} diff --git a/presto-raptor/pom.xml b/presto-raptor/pom.xml index 848441a1039b1..5a5b49ed4ee0e 100644 --- a/presto-raptor/pom.xml +++ b/presto-raptor/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-raptor diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorErrorCode.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorErrorCode.java index 90219817fd9fc..6b105e63090e8 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorErrorCode.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorErrorCode.java @@ -34,7 +34,8 @@ public enum RaptorErrorCode RAPTOR_REASSIGNMENT_THROTTLE(9, EXTERNAL), RAPTOR_RECOVERY_TIMEOUT(10, EXTERNAL), RAPTOR_CORRUPT_METADATA(11, EXTERNAL), - RAPTOR_LOCAL_DISK_FULL(12, EXTERNAL); + RAPTOR_LOCAL_DISK_FULL(12, EXTERNAL), + RAPTOR_BACKUP_CORRUPTION(13, EXTERNAL); private final ErrorCode errorCode; diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java index aef3c77940006..984ed34bd11c4 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/RaptorMetadata.java @@ -474,6 +474,10 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHan @Override public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout) { + if (viewExists(session, tableMetadata.getTable())) { + throw new PrestoException(ALREADY_EXISTS, "View already exists: " + tableMetadata.getTable()); + } + Optional partitioning = layout .map(ConnectorNewTableLayout::getPartitioning) .map(RaptorPartitioningHandle.class::cast); @@ -692,7 +696,9 @@ public Optional finishInsert(ConnectorSession session, List columns = handle.getColumnHandles().stream().map(ColumnInfo::fromHandle).collect(toList()); long updateTime = session.getStartTime(); - shardManager.commitShards(transactionId, tableId, columns, parseFragments(fragments), externalBatchId, updateTime); + Collection shards = parseFragments(fragments); + log.info("Committing insert into tableId %s (queryId: %s, shards: %s, columns: %s)", handle.getTableId(), session.getQueryId(), shards.size(), columns.size()); + shardManager.commitShards(transactionId, tableId, columns, shards, externalBatchId, updateTime); clearRollback(); @@ -772,6 +778,10 @@ public void createView(ConnectorSession session, SchemaTableName viewName, Strin String schemaName = viewName.getSchemaName(); String tableName = viewName.getTableName(); + if (getTableHandle(viewName) != null) { + throw new PrestoException(ALREADY_EXISTS, "Table already exists: " + viewName); + } + if (replace) { daoTransaction(dbi, MetadataDao.class, dao -> { dao.dropView(schemaName, tableName); diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/backup/BackupManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/backup/BackupManager.java index 41129557a3dcd..ddd958b504ce4 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/backup/BackupManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/backup/BackupManager.java @@ -14,7 +14,11 @@ package com.facebook.presto.raptor.backup; import com.facebook.presto.raptor.storage.BackupStats; +import com.facebook.presto.raptor.storage.StorageService; +import com.facebook.presto.spi.PrestoException; import com.google.common.base.Throwables; +import com.google.common.io.Files; +import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.airlift.units.Duration; import org.weakref.jmx.Flatten; @@ -30,6 +34,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; +import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_BACKUP_CORRUPTION; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.units.DataSize.Unit.BYTE; @@ -40,23 +45,27 @@ public class BackupManager { + private static final Logger log = Logger.get(BackupManager.class); + private final Optional backupStore; + private final StorageService storageService; private final ExecutorService executorService; private final AtomicInteger pendingBackups = new AtomicInteger(); private final BackupStats stats = new BackupStats(); @Inject - public BackupManager(Optional backupStore, BackupConfig config) + public BackupManager(Optional backupStore, StorageService storageService, BackupConfig config) { - this(backupStore, config.getBackupThreads()); + this(backupStore, storageService, config.getBackupThreads()); } - public BackupManager(Optional backupStore, int backupThreads) + public BackupManager(Optional backupStore, StorageService storageService, int backupThreads) { checkArgument(backupThreads > 0, "backupThreads must be > 0"); this.backupStore = requireNonNull(backupStore, "backupStore is null"); + this.storageService = requireNonNull(storageService, "storageService is null"); this.executorService = newFixedThreadPool(backupThreads, daemonThreadsNamed("background-shard-backup-%s")); } @@ -104,6 +113,29 @@ public void run() backupStore.get().backupShard(uuid, source); stats.addCopyShardDataRate(new DataSize(source.length(), BYTE), Duration.nanosSince(start)); + + File restored = new File(storageService.getStagingFile(uuid) + ".validate"); + backupStore.get().restoreShard(uuid, restored); + + if (!Files.equal(source, restored)) { + stats.incrementBackupCorruption(); + + File quarantineBase = storageService.getQuarantineFile(uuid); + File quarantineOriginal = new File(quarantineBase.getPath() + ".original"); + File quarantineRestored = new File(quarantineBase.getPath() + ".restored"); + + log.error("Backup is corrupt after write. Quarantining local file: %s", quarantineBase); + if (!source.renameTo(quarantineOriginal) || !restored.renameTo(quarantineRestored)) { + log.warn("Quarantine of corrupt backup shard failed: %s", uuid); + } + + throw new PrestoException(RAPTOR_BACKUP_CORRUPTION, "Backup is corrupt after write: " + uuid); + } + + if (!restored.delete()) { + log.warn("Failed to delete staging file: %s", restored); + } + stats.incrementBackupSuccess(); } catch (Throwable t) { diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java index 927f25f6cf0ae..758aa1c75d09a 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/DatabaseShardManager.java @@ -72,6 +72,7 @@ import static com.facebook.presto.raptor.util.ArrayUtil.intArrayToBytes; import static com.facebook.presto.raptor.util.DatabaseUtil.bindOptionalInt; import static com.facebook.presto.raptor.util.DatabaseUtil.isSyntaxOrAccessError; +import static com.facebook.presto.raptor.util.DatabaseUtil.isTransactionCacheFullError; import static com.facebook.presto.raptor.util.DatabaseUtil.metadataError; import static com.facebook.presto.raptor.util.DatabaseUtil.runIgnoringConstraintViolation; import static com.facebook.presto.raptor.util.DatabaseUtil.runTransaction; @@ -365,6 +366,10 @@ private void runCommit(long transactionId, HandleConsumer callback) return; } catch (DBIException e) { + if (isTransactionCacheFullError(e)) { + throw metadataError(e, "Transaction too large"); + } + propagateIfInstanceOf(e.getCause(), PrestoException.class); if (attempt == maxAttempts) { throw metadataError(e); @@ -518,6 +523,12 @@ private Map toNodeIdMap(Collection shards) return Maps.toMap(identifiers, this::getOrCreateNodeId); } + @Override + public ShardMetadata getShard(UUID shardUuid) + { + return dao.getShard(shardUuid); + } + @Override public Set getNodeShards(String nodeIdentifier) { @@ -758,8 +769,8 @@ private static List insertShards(Connection connection, long tableId, List throws SQLException { String sql = "" + - "INSERT INTO shards (shard_uuid, table_id, create_time, row_count, compressed_size, uncompressed_size, bucket_number)\n" + - "VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?, ?, ?)"; + "INSERT INTO shards (shard_uuid, table_id, create_time, row_count, compressed_size, uncompressed_size, xxhash64, bucket_number)\n" + + "VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?, ?, ?, ?)"; try (PreparedStatement statement = connection.prepareStatement(sql, RETURN_GENERATED_KEYS)) { for (ShardInfo shard : shards) { @@ -768,7 +779,8 @@ private static List insertShards(Connection connection, long tableId, List statement.setLong(3, shard.getRowCount()); statement.setLong(4, shard.getCompressedSize()); statement.setLong(5, shard.getUncompressedSize()); - bindOptionalInt(statement, 6, shard.getBucketNumber()); + statement.setLong(6, shard.getXxhash64()); + bindOptionalInt(statement, 7, shard.getBucketNumber()); statement.addBatch(); } statement.executeBatch(); diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/SchemaDao.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/SchemaDao.java index 31b5879bba264..952dbffcbef2f 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/SchemaDao.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/SchemaDao.java @@ -90,9 +90,10 @@ public interface SchemaDao " row_count BIGINT NOT NULL,\n" + " compressed_size BIGINT NOT NULL,\n" + " uncompressed_size BIGINT NOT NULL,\n" + + " xxhash64 BIGINT NOT NULL,\n" + " UNIQUE (shard_uuid),\n" + // include a covering index organized by table_id - " UNIQUE (table_id, bucket_number, shard_id, shard_uuid, create_time, row_count, compressed_size, uncompressed_size),\n" + + " UNIQUE (table_id, bucket_number, shard_id, shard_uuid, create_time, row_count, compressed_size, uncompressed_size, xxhash64),\n" + " FOREIGN KEY (table_id) REFERENCES tables (table_id)\n" + ")") void createTableShards(); diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardDao.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardDao.java index b44a83a80e53a..22a6f43353a35 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardDao.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardDao.java @@ -36,6 +36,8 @@ public interface ShardDao int CLEANABLE_SHARDS_BATCH_SIZE = 1000; int CLEANUP_TRANSACTIONS_BATCH_SIZE = 10_000; + String SHARD_METADATA_COLUMNS = "table_id, shard_id, shard_uuid, bucket_number, row_count, compressed_size, uncompressed_size, xxhash64"; + @SqlUpdate("INSERT INTO nodes (node_identifier) VALUES (:nodeIdentifier)") @GetGeneratedKeys int insertNode(@Bind("nodeIdentifier") String nodeIdentifier); @@ -59,7 +61,11 @@ public interface ShardDao @Mapper(RaptorNode.Mapper.class) List getNodes(); - @SqlQuery("SELECT table_id, shard_id, shard_uuid, bucket_number, row_count, compressed_size, uncompressed_size\n" + + @SqlQuery("SELECT " + SHARD_METADATA_COLUMNS + " FROM shards WHERE shard_uuid = :shardUuid") + @Mapper(ShardMetadata.Mapper.class) + ShardMetadata getShard(@Bind("shardUuid") UUID shardUuid); + + @SqlQuery("SELECT " + SHARD_METADATA_COLUMNS + "\n" + "FROM (\n" + " SELECT s.*\n" + " FROM shards s\n" + diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardInfo.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardInfo.java index 2e6ac712cdf8e..34b5faab38ce4 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardInfo.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardInfo.java @@ -25,6 +25,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class ShardInfo @@ -36,6 +37,7 @@ public class ShardInfo private final long rowCount; private final long compressedSize; private final long uncompressedSize; + private final long xxhash64; @JsonCreator public ShardInfo( @@ -45,7 +47,8 @@ public ShardInfo( @JsonProperty("columnStats") List columnStats, @JsonProperty("rowCount") long rowCount, @JsonProperty("compressedSize") long compressedSize, - @JsonProperty("uncompressedSize") long uncompressedSize) + @JsonProperty("uncompressedSize") long uncompressedSize, + @JsonProperty("xxhash64") long xxhash64) { this.shardUuid = requireNonNull(shardUuid, "shardUuid is null"); this.bucketNumber = requireNonNull(bucketNumber, "bucketNumber is null"); @@ -58,6 +61,8 @@ public ShardInfo( this.rowCount = rowCount; this.compressedSize = compressedSize; this.uncompressedSize = uncompressedSize; + + this.xxhash64 = xxhash64; } @JsonProperty @@ -102,6 +107,12 @@ public long getUncompressedSize() return uncompressedSize; } + @JsonProperty + public long getXxhash64() + { + return xxhash64; + } + @Override public String toString() { @@ -113,6 +124,7 @@ public String toString() .add("rowCount", rowCount) .add("compressedSize", compressedSize) .add("uncompressedSize", uncompressedSize) + .add("xxhash64", format("%016x", xxhash64)) .omitNullValues() .toString(); } diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardManager.java index 683986c9bb0c5..fd258d4b32d96 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardManager.java @@ -52,6 +52,11 @@ public interface ShardManager */ void replaceShardUuids(long transactionId, long tableId, List columns, Set oldShardUuids, Collection newShards, OptionalLong updateTime); + /** + * Get shard metadata for a shard. + */ + ShardMetadata getShard(UUID shardUuid); + /** * Get shard metadata for shards on a given node. */ diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardMetadata.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardMetadata.java index 9e5ba724c6bdc..13e56dfcd437f 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardMetadata.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardMetadata.java @@ -25,10 +25,12 @@ import java.util.UUID; import static com.facebook.presto.raptor.util.DatabaseUtil.getOptionalInt; +import static com.facebook.presto.raptor.util.DatabaseUtil.getOptionalLong; import static com.facebook.presto.raptor.util.UuidUtil.uuidFromBytes; import static com.google.common.base.MoreObjects.ToStringHelper; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class ShardMetadata @@ -40,6 +42,7 @@ public class ShardMetadata private final long rowCount; private final long compressedSize; private final long uncompressedSize; + private final OptionalLong xxhash64; private final OptionalLong rangeStart; private final OptionalLong rangeEnd; @@ -51,6 +54,7 @@ public ShardMetadata( long rowCount, long compressedSize, long uncompressedSize, + OptionalLong xxhash64, OptionalLong rangeStart, OptionalLong rangeEnd) { @@ -67,6 +71,7 @@ public ShardMetadata( this.rowCount = rowCount; this.compressedSize = compressedSize; this.uncompressedSize = uncompressedSize; + this.xxhash64 = requireNonNull(xxhash64, "xxhash64 is null"); this.rangeStart = requireNonNull(rangeStart, "rangeStart is null"); this.rangeEnd = requireNonNull(rangeEnd, "rangeEnd is null"); } @@ -106,6 +111,11 @@ public long getUncompressedSize() return uncompressedSize; } + public OptionalLong getXxhash64() + { + return xxhash64; + } + public OptionalLong getRangeStart() { return rangeStart; @@ -126,6 +136,7 @@ public ShardMetadata withTimeRange(long rangeStart, long rangeEnd) rowCount, compressedSize, uncompressedSize, + xxhash64, OptionalLong.of(rangeStart), OptionalLong.of(rangeEnd)); } @@ -143,6 +154,9 @@ public String toString() if (bucketNumber.isPresent()) { stringHelper.add("bucketNumber", bucketNumber.getAsInt()); } + if (xxhash64.isPresent()) { + stringHelper.add("xxhash64", format("%16x", xxhash64.getAsLong())); + } if (rangeStart.isPresent()) { stringHelper.add("rangeStart", rangeStart.getAsLong()); } @@ -168,6 +182,7 @@ public boolean equals(Object o) Objects.equals(rowCount, that.rowCount) && Objects.equals(compressedSize, that.compressedSize) && Objects.equals(uncompressedSize, that.uncompressedSize) && + Objects.equals(xxhash64, that.xxhash64) && Objects.equals(shardUuid, that.shardUuid) && Objects.equals(rangeStart, that.rangeStart) && Objects.equals(rangeEnd, that.rangeEnd); @@ -184,6 +199,7 @@ public int hashCode() rowCount, compressedSize, uncompressedSize, + xxhash64, rangeStart, rangeEnd); } @@ -203,6 +219,7 @@ public ShardMetadata map(int index, ResultSet r, StatementContext ctx) r.getLong("row_count"), r.getLong("compressed_size"), r.getLong("uncompressed_size"), + getOptionalLong(r, "xxhash64"), OptionalLong.empty(), OptionalLong.empty()); } diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/BackupStats.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/BackupStats.java index caffeb06f5b53..fa3d088eb6fc8 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/BackupStats.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/BackupStats.java @@ -34,6 +34,7 @@ public class BackupStats private final CounterStat backupSuccess = new CounterStat(); private final CounterStat backupFailure = new CounterStat(); + private final CounterStat backupCorruption = new CounterStat(); public void addCopyShardDataRate(DataSize size, Duration duration) { @@ -58,6 +59,11 @@ public void incrementBackupFailure() backupFailure.update(1); } + public void incrementBackupCorruption() + { + backupCorruption.update(1); + } + @Managed @Nested public DistributionStat getCopyToBackupBytesPerSecond() @@ -99,4 +105,11 @@ public CounterStat getBackupFailure() { return backupFailure; } + + @Managed + @Nested + public CounterStat getBackupCorruption() + { + return backupCorruption; + } } diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/FileStorageService.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/FileStorageService.java index bb7fb87f315ba..7ad6944188879 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/FileStorageService.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/FileStorageService.java @@ -46,6 +46,7 @@ public class FileStorageService private final File baseStorageDir; private final File baseStagingDir; + private final File baseQuarantineDir; @Inject public FileStorageService(StorageManagerConfig config) @@ -58,6 +59,7 @@ public FileStorageService(File dataDirectory) File baseDataDir = requireNonNull(dataDirectory, "dataDirectory is null"); this.baseStorageDir = new File(baseDataDir, "storage"); this.baseStagingDir = new File(baseDataDir, "staging"); + this.baseQuarantineDir = new File(baseDataDir, "quarantine"); } @Override @@ -68,6 +70,7 @@ public void start() deleteStagingFilesAsync(); createParents(new File(baseStagingDir, ".")); createParents(new File(baseStorageDir, ".")); + createParents(new File(baseQuarantineDir, ".")); } @Override @@ -96,6 +99,13 @@ public File getStagingFile(UUID shardUuid) return new File(baseStagingDir, name); } + @Override + public File getQuarantineFile(UUID shardUuid) + { + String name = getFileSystemPath(new File("/"), shardUuid).getName(); + return new File(baseQuarantineDir, name); + } + @Override public Set getStorageShards() { diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java index 938e4f1dedc18..d419f1b502665 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/OrcStorageManager.java @@ -53,6 +53,7 @@ import io.airlift.json.JsonCodec; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.airlift.slice.XxHash64; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -61,8 +62,10 @@ import java.io.Closeable; import java.io.File; +import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.util.ArrayList; import java.util.BitSet; @@ -106,6 +109,7 @@ import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.json.JsonCodec.jsonCodec; +import static io.airlift.units.DataSize.Unit.PETABYTE; import static java.lang.Math.min; import static java.nio.file.StandardCopyOption.ATOMIC_MOVE; import static java.util.Objects.requireNonNull; @@ -122,6 +126,8 @@ public class OrcStorageManager private static final JsonCodec SHARD_DELTA_CODEC = jsonCodec(ShardDelta.class); private static final long MAX_ROWS = 1_000_000_000; + // TODO: do not limit the max size of blocks to read for now; enable the limit when the Hive connector is ready + private static final DataSize HUGE_MAX_READ_BLOCK_SIZE = new DataSize(1, PETABYTE); private static final JsonCodec METADATA_CODEC = jsonCodec(OrcFileMetadata.class); private final String nodeId; @@ -225,7 +231,7 @@ public ConnectorPageSource getPageSource( AggregatedMemoryContext systemMemoryUsage = new AggregatedMemoryContext(); try { - OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), readerAttributes.getMaxMergeDistance(), readerAttributes.getMaxReadSize()); + OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), readerAttributes.getMaxMergeDistance(), readerAttributes.getMaxReadSize(), HUGE_MAX_READ_BLOCK_SIZE); Map indexMap = columnIdIndex(reader.getColumnNames()); ImmutableMap.Builder includedColumns = ImmutableMap.builder(); @@ -359,13 +365,13 @@ private static FileOrcDataSource fileOrcDataSource(ReaderAttributes readerAttrib private ShardInfo createShardInfo(UUID shardUuid, OptionalInt bucketNumber, File file, Set nodes, long rowCount, long uncompressedSize) { - return new ShardInfo(shardUuid, bucketNumber, nodes, computeShardStats(file), rowCount, file.length(), uncompressedSize); + return new ShardInfo(shardUuid, bucketNumber, nodes, computeShardStats(file), rowCount, file.length(), uncompressedSize, xxhash64(file)); } private List computeShardStats(File file) { try (OrcDataSource dataSource = fileOrcDataSource(defaultReaderAttributes, file)) { - OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), defaultReaderAttributes.getMaxMergeDistance(), defaultReaderAttributes.getMaxReadSize()); + OrcReader reader = new OrcReader(dataSource, new OrcMetadataReader(), defaultReaderAttributes.getMaxMergeDistance(), defaultReaderAttributes.getMaxReadSize(), HUGE_MAX_READ_BLOCK_SIZE); ImmutableList.Builder list = ImmutableList.builder(); for (ColumnInfo info : getColumnInfo(reader)) { @@ -453,6 +459,16 @@ private List getColumnInfoFromOrcColumnTypes(List orcColumnN return list.build(); } + static long xxhash64(File file) + { + try (InputStream in = new FileInputStream(file)) { + return XxHash64.hash(in); + } + catch (IOException e) { + throw new PrestoException(RAPTOR_ERROR, "Failed to read file: " + file, e); + } + } + private static Optional getOrcFileMetadata(OrcReader reader) { return Optional.ofNullable(reader.getFooter().getUserMetadata().get(OrcFileMetadata.KEY)) diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java index b21029d28ff09..b8ab117f7476c 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/Row.java @@ -42,9 +42,9 @@ public class Row { private final List columns; - private final int sizeInBytes; + private final long sizeInBytes; - public Row(List columns, int sizeInBytes) + public Row(List columns, long sizeInBytes) { this.columns = requireNonNull(columns, "columns is null"); checkArgument(sizeInBytes >= 0, "sizeInBytes must be >= 0"); @@ -56,7 +56,7 @@ public List getColumns() return columns; } - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @@ -70,7 +70,7 @@ public static Row extractRow(Page page, int position, List types) for (int channel = 0; channel < page.getChannelCount(); channel++) { Block block = page.getBlock(channel); Type type = types.get(channel); - int size; + long size; Object value = getNativeContainerValue(type, block, position); if (value == null) { size = SIZE_OF_BYTE; @@ -180,7 +180,7 @@ private static Object nativeContainerToOrcValue(Type type, Object nativeValue) private static class RowBuilder { - private int rowSize; + private long rowSize; private final List columns; public RowBuilder(int columnCount) @@ -188,7 +188,7 @@ public RowBuilder(int columnCount) this.columns = new ArrayList<>(columnCount); } - public void add(Object value, int size) + public void add(Object value, long size) { columns.add(value); rowSize += size; diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryManager.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryManager.java index d36fb37b3df14..3d6d4203d4967 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryManager.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryManager.java @@ -55,7 +55,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; +import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_BACKUP_CORRUPTION; +import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_ERROR; import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_RECOVERY_ERROR; +import static com.facebook.presto.raptor.storage.OrcStorageManager.xxhash64; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.Threads.daemonThreadsNamed; @@ -165,7 +168,7 @@ private synchronized void enqueueMissingShards() for (ShardMetadata shard : getMissingShards()) { stats.incrementBackgroundShardRecovery(); Futures.addCallback( - shardQueue.submit(MissingShard.createBackgroundMissingShard(shard.getShardUuid(), shard.getCompressedSize())), + shardQueue.submit(new MissingShard(shard.getShardUuid(), shard.getCompressedSize(), shard.getXxhash64(), false)), failureCallback(t -> log.warn(t, "Error recovering shard: %s", shard.getShardUuid()))); } } @@ -190,13 +193,16 @@ private boolean shardNeedsRecovery(UUID shardUuid, long shardSize) public Future recoverShard(UUID shardUuid) throws ExecutionException { - requireNonNull(shardUuid, "shardUuid is null"); + ShardMetadata shard = shardManager.getShard(shardUuid); + if (shard == null) { + throw new PrestoException(RAPTOR_ERROR, "Shard does not exist in database: " + shardUuid); + } stats.incrementActiveShardRecovery(); - return shardQueue.submit(MissingShard.createActiveMissingShard(shardUuid)); + return shardQueue.submit(new MissingShard(shardUuid, shard.getCompressedSize(), shard.getXxhash64(), true)); } @VisibleForTesting - void restoreFromBackup(UUID shardUuid, OptionalLong shardSize) + void restoreFromBackup(UUID shardUuid, long shardSize, OptionalLong shardXxhash64) { File storageFile = storageService.getStorageFile(shardUuid); @@ -206,11 +212,16 @@ void restoreFromBackup(UUID shardUuid, OptionalLong shardSize) } if (storageFile.exists()) { - if (!shardSize.isPresent() || (storageFile.length() == shardSize.getAsLong())) { + if (!isFileCorrupt(storageFile, shardSize, shardXxhash64)) { return; } - log.warn("Local shard file is corrupt. Deleting local file: %s", storageFile); - storageFile.delete(); + stats.incrementCorruptLocalFile(); + File quarantine = getQuarantineFile(shardUuid); + log.error("Local file is corrupt. Quarantining local file: %s", quarantine); + if (!storageFile.renameTo(quarantine)) { + log.warn("Quarantine of corrupt local file failed: %s", shardUuid); + storageFile.delete(); + } } // create a temporary file in the staging directory @@ -253,16 +264,37 @@ void restoreFromBackup(UUID shardUuid, OptionalLong shardSize) stagingFile.delete(); } - if (!storageFile.exists() || (shardSize.isPresent() && (storageFile.length() != shardSize.getAsLong()))) { + if (!storageFile.exists()) { + stats.incrementShardRecoveryFailure(); + throw new PrestoException(RAPTOR_RECOVERY_ERROR, "File does not exist after recovery: " + shardUuid); + } + + if (isFileCorrupt(storageFile, shardSize, shardXxhash64)) { stats.incrementShardRecoveryFailure(); - log.info("Files do not match after recovery. Deleting local file: " + shardUuid); - storageFile.delete(); - throw new PrestoException(RAPTOR_RECOVERY_ERROR, "File not recovered correctly: " + shardUuid); + stats.incrementCorruptRecoveredFile(); + File quarantine = getQuarantineFile(shardUuid); + log.error("Local file is corrupt after recovery. Quarantining local file: %s", quarantine); + if (!storageFile.renameTo(quarantine)) { + log.warn("Quarantine of corrupt recovered file failed: %s", shardUuid); + storageFile.delete(); + } + throw new PrestoException(RAPTOR_BACKUP_CORRUPTION, "Backup is corrupt after read: " + shardUuid); } stats.incrementShardRecoverySuccess(); } + private File getQuarantineFile(UUID shardUuid) + { + File file = storageService.getQuarantineFile(shardUuid); + return new File(file.getPath() + ".corrupt." + System.currentTimeMillis()); + } + + private static boolean isFileCorrupt(File file, long size, OptionalLong xxhash64) + { + return (file.length() != size) || (xxhash64.isPresent() && (xxhash64(file) != xxhash64.getAsLong())); + } + @VisibleForTesting static class MissingShardComparator implements Comparator @@ -287,20 +319,22 @@ private class MissingShardRecovery implements MissingShardRunnable { private final UUID shardUuid; - private final OptionalLong shardSize; + private final long shardSize; + private final OptionalLong shardXxhash64; private final boolean active; - public MissingShardRecovery(UUID shardUuid, OptionalLong shardSize, boolean active) + public MissingShardRecovery(UUID shardUuid, long shardSize, OptionalLong shardXxhash64, boolean active) { this.shardUuid = requireNonNull(shardUuid, "shardUuid is null"); - this.shardSize = requireNonNull(shardSize, "shardSize is null"); + this.shardSize = shardSize; + this.shardXxhash64 = requireNonNull(shardXxhash64, "shardXxhash64 is null"); this.active = active; } @Override public void run() { - restoreFromBackup(shardUuid, shardSize); + restoreFromBackup(shardUuid, shardSize, shardXxhash64); } @Override @@ -313,36 +347,33 @@ public boolean isActive() private static final class MissingShard { private final UUID shardUuid; - private final OptionalLong shardSize; + private final long shardSize; + private final OptionalLong shardXxhash64; private final boolean active; - private MissingShard(UUID shardUuid, OptionalLong shardSize, boolean active) + public MissingShard(UUID shardUuid, long shardSize, OptionalLong shardXxhash64, boolean active) { this.shardUuid = requireNonNull(shardUuid, "shardUuid is null"); - this.shardSize = requireNonNull(shardSize, "shardSize is null"); + this.shardSize = shardSize; + this.shardXxhash64 = requireNonNull(shardXxhash64, "shardXxhash64 is null"); this.active = active; } - public static MissingShard createBackgroundMissingShard(UUID shardUuid, long shardSize) - { - return new MissingShard(shardUuid, OptionalLong.of(shardSize), false); - } - - public static MissingShard createActiveMissingShard(UUID shardUuid) - { - return new MissingShard(shardUuid, OptionalLong.empty(), true); - } - public UUID getShardUuid() { return shardUuid; } - public OptionalLong getShardSize() + public long getShardSize() { return shardSize; } + public OptionalLong getShardXxhash64() + { + return shardXxhash64; + } + public boolean isActive() { return active; @@ -394,6 +425,7 @@ public ListenableFuture load(MissingShard missingShard) MissingShardRecovery task = new MissingShardRecovery( missingShard.getShardUuid(), missingShard.getShardSize(), + missingShard.getShardXxhash64(), missingShard.isActive()); ListenableFuture future = shardRecoveryExecutor.submit(task); future.addListener(() -> queuedMissingShards.invalidate(missingShard), directExecutor()); diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryStats.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryStats.java index 189c2bdbb9873..9c7c8ddd7ca9f 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryStats.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardRecoveryStats.java @@ -35,6 +35,9 @@ public class ShardRecoveryStats private final DistributionStat shardRecoveryTimeInMilliSeconds = new DistributionStat(); private final DistributionStat shardRecoveryBytesPerSecond = new DistributionStat(); + private final CounterStat corruptLocalFile = new CounterStat(); + private final CounterStat corruptRecoveredFile = new CounterStat(); + public void incrementBackgroundShardRecovery() { backgroundShardRecovery.update(1); @@ -67,6 +70,16 @@ public void addShardRecoveryDataRate(DataSize rate, DataSize size, Duration dura shardRecoveryTimeInMilliSeconds.add(duration.toMillis()); } + public void incrementCorruptLocalFile() + { + corruptLocalFile.update(1); + } + + public void incrementCorruptRecoveredFile() + { + corruptRecoveredFile.update(1); + } + @Managed @Nested public CounterStat getActiveShardRecovery() @@ -122,4 +135,18 @@ public DistributionStat getShardRecoveryShardSizeBytes() { return shardRecoveryShardSizeBytes; } + + @Managed + @Nested + public CounterStat getCorruptLocalFile() + { + return corruptLocalFile; + } + + @Managed + @Nested + public CounterStat getCorruptRecoveredFile() + { + return corruptRecoveredFile; + } } diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/StorageService.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/StorageService.java index 5c05867155a1b..988423586908c 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/StorageService.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/StorageService.java @@ -31,5 +31,7 @@ void start() File getStagingFile(UUID shardUuid); + File getQuarantineFile(UUID shardUuid); + Set getStorageShards(); } diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ResultSetValues.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ResultSetValues.java index 437ac475db7b4..48dfbe37f5762 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ResultSetValues.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ResultSetValues.java @@ -29,6 +29,7 @@ import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.slice.Slices.wrappedBuffer; +import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; @@ -52,7 +53,7 @@ public ResultSetValues(List types) this.nulls = new boolean[types.size()]; } - int extractValues(ResultSet resultSet, Set uuidColumns) + int extractValues(ResultSet resultSet, Set uuidColumns, Set hexColumns) throws SQLException { checkArgument(resultSet != null, "resultSet is null"); @@ -88,6 +89,11 @@ else if (javaType == Slice.class) { nulls[i] = resultSet.wasNull(); strings[i] = nulls[i] ? null : uuidFromBytes(bytes).toString().toLowerCase(ENGLISH); } + else if (hexColumns.contains(i)) { + long value = resultSet.getLong(i + 1); + nulls[i] = resultSet.wasNull(); + strings[i] = nulls[i] ? null : format("%016x", value); + } else { String value = resultSet.getString(i + 1); nulls[i] = resultSet.wasNull(); diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ShardMetadataRecordCursor.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ShardMetadataRecordCursor.java index 158d26afc4c88..7ffa24c7efa4b 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ShardMetadataRecordCursor.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/systemtables/ShardMetadataRecordCursor.java @@ -48,6 +48,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkPositionIndex; import static com.google.common.base.Preconditions.checkState; @@ -59,6 +60,7 @@ public class ShardMetadataRecordCursor implements RecordCursor { private static final String SHARD_UUID = "shard_uuid"; + private static final String XXHASH64 = "xxhash64"; private static final String SCHEMA_NAME = "table_schema"; private static final String TABLE_NAME = "table_name"; private static final String MIN_TIMESTAMP = "min_timestamp"; @@ -75,6 +77,7 @@ public class ShardMetadataRecordCursor new ColumnMetadata("uncompressed_size", BIGINT), new ColumnMetadata("compressed_size", BIGINT), new ColumnMetadata("row_count", BIGINT), + new ColumnMetadata(XXHASH64, createVarcharType(16)), new ColumnMetadata(MIN_TIMESTAMP, TIMESTAMP), new ColumnMetadata(MAX_TIMESTAMP, TIMESTAMP))); @@ -107,16 +110,15 @@ public ShardMetadataRecordCursor(IDBI dbi, TupleDomain tupleDomain) this.resultSet = getNextResultSet(); } - private static String constructSqlTemplate(List columnNames, String indexTableName) + private static String constructSqlTemplate(List columnNames, long tableId) { - StringBuilder sql = new StringBuilder(); - sql.append("SELECT\n"); - sql.append(Joiner.on(",\n").join(columnNames)); - sql.append("\nFROM ").append(indexTableName).append(" x\n"); - sql.append("JOIN shards ON (x.shard_id = shards.shard_id)\n"); - sql.append("JOIN tables ON (shards.table_id = tables.table_id)\n"); - - return sql.toString(); + return format("SELECT %s\nFROM %s x\n" + + "JOIN shards ON (x.shard_id = shards.shard_id AND shards.table_id = %s)\n" + + "JOIN tables ON (tables.table_id = %s)\n", + Joiner.on(", ").join(columnNames), + shardIndexTable(tableId), + tableId, + tableId); } private static List createQualifiedColumnNames() @@ -129,6 +131,7 @@ private static List createQualifiedColumnNames() .add("shards" + "." + COLUMNS.get(4).getName()) .add("shards" + "." + COLUMNS.get(5).getName()) .add("shards" + "." + COLUMNS.get(6).getName()) + .add("shards" + "." + COLUMNS.get(7).getName()) .add("min_timestamp") .add("max_timestamp") .build(); @@ -178,7 +181,10 @@ public boolean advanceNextPosition() return false; } } - completedBytes += resultSetValues.extractValues(resultSet, ImmutableSet.of(getColumnIndex(SHARD_METADATA, SHARD_UUID))); + completedBytes += resultSetValues.extractValues( + resultSet, + ImmutableSet.of(getColumnIndex(SHARD_METADATA, SHARD_UUID)), + ImmutableSet.of(getColumnIndex(SHARD_METADATA, XXHASH64))); return true; } catch (SQLException | DBIException e) { @@ -268,7 +274,7 @@ private ResultSet getNextResultSet() connection = dbi.open().getConnection(); statement = PreparedStatementBuilder.create( connection, - constructSqlTemplate(columnNames, shardIndexTable(tableId)), + constructSqlTemplate(columnNames, tableId), columnNames, TYPES, ImmutableSet.of(getColumnIndex(SHARD_METADATA, SHARD_UUID)), diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java index 6c11ee08af8de..083cce2342b33 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/util/DatabaseUtil.java @@ -25,13 +25,16 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.Arrays; import java.util.OptionalInt; import java.util.OptionalLong; import java.util.function.Consumer; +import java.util.function.Predicate; import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_METADATA_ERROR; import static com.google.common.base.Throwables.propagateIfInstanceOf; import static com.google.common.reflect.Reflection.newProxy; +import static com.mysql.jdbc.MysqlErrorNumbers.ER_TRANS_CACHE_FULL; import static java.sql.Types.INTEGER; import static java.util.Objects.requireNonNull; @@ -75,9 +78,14 @@ public static void daoTransaction(IDBI dbi, Class daoType, Consumer ca }); } + public static PrestoException metadataError(Throwable cause, String message) + { + return new PrestoException(RAPTOR_METADATA_ERROR, message, cause); + } + public static PrestoException metadataError(Throwable cause) { - return new PrestoException(RAPTOR_METADATA_ERROR, "Failed to perform metadata operation", cause); + return metadataError(cause, "Failed to perform metadata operation"); } /** @@ -134,6 +142,32 @@ public static boolean isSyntaxOrAccessError(Exception e) return sqlCodeStartsWith(e, "42"); } + public static boolean isTransactionCacheFullError(Exception e) + { + return mySqlErrorCodeMatches(e, ER_TRANS_CACHE_FULL); + } + + /** + * Check if an exception is caused by a MySQL exception of certain error code + */ + private static boolean mySqlErrorCodeMatches(Exception e, int errorCode) + { + return Throwables.getCausalChain(e).stream() + .filter(SQLException.class::isInstance) + .map(SQLException.class::cast) + .filter(t -> t.getErrorCode() == errorCode) + .map(Throwable::getStackTrace) + .anyMatch(isMySQLException()); + } + + private static Predicate isMySQLException() + { + // check if the exception is a mysql exception by matching the package name in the stack trace + return s -> Arrays.stream(s) + .map(StackTraceElement::getClassName) + .anyMatch(t -> t.startsWith("com.mysql.jdbc.")); + } + private static boolean sqlCodeStartsWith(Exception e, String code) { for (Throwable throwable : Throwables.getCausalChain(e)) { diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/util/SyncingFileSystem.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/util/SyncingFileSystem.java index aadb77e5df9d5..420d8052baa73 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/util/SyncingFileSystem.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/util/SyncingFileSystem.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.raptor.util; +import io.airlift.slice.XxHash64; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.Path; @@ -21,10 +22,14 @@ import java.io.BufferedOutputStream; import java.io.File; +import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; +import static java.util.Objects.requireNonNull; + public final class SyncingFileSystem extends RawLocalFileSystem { @@ -53,12 +58,16 @@ public FSDataOutputStream create(Path path, boolean overwrite, int bufferSize, s private static class LocalFileOutputStream extends OutputStream { + private final byte[] oneByte = new byte[1]; + private final XxHash64 hash = new XxHash64(); + private final File file; private final FileOutputStream out; private boolean closed; private LocalFileOutputStream(File file) throws IOException { + this.file = requireNonNull(file, "file is null"); this.out = new FileOutputStream(file); } @@ -74,6 +83,13 @@ public void close() flush(); out.getFD().sync(); out.close(); + + // extremely paranoid code to detect a broken local file system + try (InputStream in = new FileInputStream(file)) { + if (hash.hash() != XxHash64.hash(in)) { + throw new IOException("File is corrupt after write"); + } + } } @Override @@ -88,13 +104,16 @@ public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); + hash.update(b, off, len); } + @SuppressWarnings("NumericCastThatLosesPrecision") @Override public void write(int b) throws IOException { - out.write(b); + oneByte[0] = (byte) (b & 0xFF); + write(oneByte, 0, 1); } } } diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java index 87956089772bb..773460ea33b77 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/TestRaptorIntegrationSmokeTest.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.raptor; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; -import com.facebook.presto.type.ArrayType; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -88,6 +88,30 @@ public void testMapTable() assertUpdate("DROP TABLE map_test"); } + @Test + public void testCreateTableViewAlreadyExists() + throws Exception + { + assertUpdate("CREATE VIEW view_already_exists AS SELECT 1 a"); + assertQueryFails("CREATE TABLE view_already_exists(a integer)", "View already exists: tpch.view_already_exists"); + assertQueryFails("CREATE TABLE View_Already_Exists(a integer)", "View already exists: tpch.view_already_exists"); + assertQueryFails("CREATE TABLE view_already_exists AS SELECT 1 a", "View already exists: tpch.view_already_exists"); + assertQueryFails("CREATE TABLE View_Already_Exists AS SELECT 1 a", "View already exists: tpch.view_already_exists"); + assertUpdate("DROP VIEW view_already_exists"); + } + + @Test + public void testCreateViewTableAlreadyExists() + throws Exception + { + assertUpdate("CREATE TABLE table_already_exists (id integer)"); + assertQueryFails("CREATE VIEW table_already_exists AS SELECT 1 a", "Table already exists: tpch.table_already_exists"); + assertQueryFails("CREATE VIEW Table_Already_Exists AS SELECT 1 a", "Table already exists: tpch.table_already_exists"); + assertQueryFails("CREATE OR REPLACE VIEW table_already_exists AS SELECT 1 a", "Table already exists: tpch.table_already_exists"); + assertQueryFails("CREATE OR REPLACE VIEW Table_Already_Exists AS SELECT 1 a", "Table already exists: tpch.table_already_exists"); + assertUpdate("DROP TABLE table_already_exists"); + } + @Test public void testInsertSelectDecimal() throws Exception diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/backup/TestBackupManager.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/backup/TestBackupManager.java index 33c21277469b8..95509adc4bd55 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/backup/TestBackupManager.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/backup/TestBackupManager.java @@ -13,42 +13,66 @@ */ package com.facebook.presto.raptor.backup; +import com.facebook.presto.raptor.storage.BackupStats; +import com.facebook.presto.raptor.storage.FileStorageService; +import com.facebook.presto.spi.PrestoException; import com.google.common.io.Files; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ThreadLocalRandom; +import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_BACKUP_CORRUPTION; +import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_BACKUP_ERROR; import static com.google.common.io.Files.createTempDir; import static io.airlift.testing.FileUtils.deleteRecursively; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; +import static org.testng.FileAssert.assertFile; @Test(singleThreaded = true) public class TestBackupManager { + private static final UUID FAILURE_UUID = randomUUID(); + private static final UUID CORRUPTION_UUID = randomUUID(); + private File temporary; - private FileBackupStore store; + private BackupStore backupStore; + private FileStorageService storageService; private BackupManager backupManager; - @BeforeClass + @BeforeMethod public void setup() throws Exception { temporary = createTempDir(); - store = new FileBackupStore(new File(temporary, "backup")); - store.start(); - backupManager = new BackupManager(Optional.of(store), 5); + + FileBackupStore fileStore = new FileBackupStore(new File(temporary, "backup")); + fileStore.start(); + backupStore = new TestingBackupStore(fileStore); + + storageService = new FileStorageService(new File(temporary, "data")); + storageService.start(); + + backupManager = new BackupManager(Optional.of(backupStore), storageService, 5); } - @AfterClass(alwaysRun = true) + @AfterMethod(alwaysRun = true) public void tearDown() throws Exception { @@ -60,6 +84,9 @@ public void tearDown() public void testSimple() throws Exception { + assertEmptyStagingDirectory(); + assertBackupStats(0, 0, 0); + List> futures = new ArrayList<>(); List uuids = new ArrayList<>(5); for (int i = 0; i < 5; i++) { @@ -71,7 +98,135 @@ public void testSimple() } futures.forEach(CompletableFuture::join); for (UUID uuid : uuids) { - assertTrue(store.shardExists(uuid)); + assertTrue(backupStore.shardExists(uuid)); + } + + assertBackupStats(5, 0, 0); + assertEmptyStagingDirectory(); + } + + @Test + public void testFailure() + throws Exception + { + assertEmptyStagingDirectory(); + assertBackupStats(0, 0, 0); + + File file = new File(temporary, "failure"); + Files.write("hello world", file, UTF_8); + + try { + backupManager.submit(FAILURE_UUID, file).get(1, SECONDS); + fail("expected exception"); + } + catch (ExecutionException wrapper) { + PrestoException e = (PrestoException) wrapper.getCause(); + assertEquals(e.getErrorCode(), RAPTOR_BACKUP_ERROR.toErrorCode()); + assertEquals(e.getMessage(), "Backup failed for testing"); + } + + assertBackupStats(0, 1, 0); + assertEmptyStagingDirectory(); + } + + @Test + public void testCorruption() + throws Exception + { + assertEmptyStagingDirectory(); + assertBackupStats(0, 0, 0); + + File file = new File(temporary, "corrupt"); + Files.write("hello world", file, UTF_8); + + try { + backupManager.submit(CORRUPTION_UUID, file).get(1, SECONDS); + fail("expected exception"); + } + catch (ExecutionException wrapper) { + PrestoException e = (PrestoException) wrapper.getCause(); + assertEquals(e.getErrorCode(), RAPTOR_BACKUP_CORRUPTION.toErrorCode()); + assertEquals(e.getMessage(), "Backup is corrupt after write: " + CORRUPTION_UUID); + } + + File quarantineBase = storageService.getQuarantineFile(CORRUPTION_UUID); + assertFile(new File(quarantineBase.getPath() + ".original")); + assertFile(new File(quarantineBase.getPath() + ".restored")); + + assertBackupStats(0, 1, 1); + assertEmptyStagingDirectory(); + } + + private void assertEmptyStagingDirectory() + { + File staging = storageService.getStagingFile(randomUUID()).getParentFile(); + assertEquals(staging.list(), new String[] {}); + } + + private void assertBackupStats(int successCount, int failureCount, int corruptionCount) + { + BackupStats stats = backupManager.getStats(); + assertEquals(stats.getBackupSuccess().getTotalCount(), successCount); + assertEquals(stats.getBackupFailure().getTotalCount(), failureCount); + assertEquals(stats.getBackupCorruption().getTotalCount(), corruptionCount); + } + + private static class TestingBackupStore + implements BackupStore + { + private final BackupStore delegate; + + private TestingBackupStore(BackupStore delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + @Override + public void backupShard(UUID uuid, File source) + { + if (uuid.equals(FAILURE_UUID)) { + throw new PrestoException(RAPTOR_BACKUP_ERROR, "Backup failed for testing"); + } + delegate.backupShard(uuid, source); + } + + @Override + public void restoreShard(UUID uuid, File target) + { + delegate.restoreShard(uuid, target); + if (uuid.equals(CORRUPTION_UUID)) { + corruptFile(target); + } + } + + @Override + public boolean deleteShard(UUID uuid) + { + return delegate.deleteShard(uuid); + } + + @Override + public boolean shardExists(UUID uuid) + { + return delegate.shardExists(uuid); + } + + private static void corruptFile(File path) + { + // flip a bit at a random offset + try (RandomAccessFile file = new RandomAccessFile(path, "rw")) { + if (file.length() == 0) { + throw new RuntimeException("file is empty"); + } + long offset = ThreadLocalRandom.current().nextLong(file.length()); + file.seek(offset); + int value = file.read() ^ 0x01; + file.seek(offset); + file.write(value); + } + catch (IOException e) { + throw new RuntimeException(e); + } } } } diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java index 534a12b9407d3..5b9ddcc6a9603 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java @@ -213,8 +213,8 @@ public void testGetNodeBytes() UUID shard1 = UUID.randomUUID(); UUID shard2 = UUID.randomUUID(); List shardNodes = ImmutableList.of( - new ShardInfo(shard1, bucketNumber, ImmutableSet.of("node1"), ImmutableList.of(), 3, 33, 333), - new ShardInfo(shard2, bucketNumber, ImmutableSet.of("node1"), ImmutableList.of(), 5, 55, 555)); + new ShardInfo(shard1, bucketNumber, ImmutableSet.of("node1"), ImmutableList.of(), 3, 33, 333, 0), + new ShardInfo(shard2, bucketNumber, ImmutableSet.of("node1"), ImmutableList.of(), 5, 55, 555, 0)); List columns = ImmutableList.of(new ColumnInfo(1, BIGINT)); shardManager.createTable(tableId, columns, false, OptionalLong.empty()); @@ -710,7 +710,7 @@ public static ShardInfo shardInfo(UUID shardUuid, String nodeIdentifier) public static ShardInfo shardInfo(UUID shardUuid, String nodeId, List columnStats) { - return new ShardInfo(shardUuid, OptionalInt.empty(), ImmutableSet.of(nodeId), columnStats, 0, 0, 0); + return new ShardInfo(shardUuid, OptionalInt.empty(), ImmutableSet.of(nodeId), columnStats, 0, 0, 0, 0); } private static Set toShardNodes(List shards) diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardCleaner.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardCleaner.java index 7e941323d8f7e..9a65a73d9296f 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardCleaner.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardCleaner.java @@ -207,7 +207,7 @@ public void testCleanLocalShards() Set shards = ImmutableSet.of(shard1, shard2, shard3, shard4); for (UUID shard : shards) { - shardDao.insertShard(shard, tableId, null, 0, 0, 0); + shardDao.insertShard(shard, tableId, null, 0, 0, 0, 0); createShardFile(shard); assertTrue(shardFileExists(shard)); } diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardDao.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardDao.java index 80f5689552514..b44f0637b7930 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardDao.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardDao.java @@ -25,7 +25,6 @@ import java.sql.SQLException; import java.util.List; -import java.util.Map; import java.util.OptionalInt; import java.util.OptionalLong; import java.util.UUID; @@ -34,6 +33,7 @@ import static io.airlift.testing.Assertions.assertInstanceOf; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -115,20 +115,19 @@ public void testNodeInsert() public void testInsertShard() { long tableId = createTable("test"); - long shardId = dao.insertShard(UUID.randomUUID(), tableId, null, 13, 42, 84); - - String sql = "SELECT table_id, row_count, compressed_size, uncompressed_size " + - "FROM shards WHERE shard_id = ?"; - List> shards = dbi.withHandle(handle -> handle.select(sql, shardId)); - - assertEquals(shards.size(), 1); - Map shard = shards.get(0); - - assertEquals(shard.get("table_id"), tableId); - assertEquals(shard.get("row_count"), 13L); - assertEquals(shard.get("compressed_size"), 42L); - assertEquals(shard.get("uncompressed_size"), 84L); - } + UUID shardUuid = UUID.randomUUID(); + long shardId = dao.insertShard(shardUuid, tableId, null, 13, 42, 84, 1234); + + ShardMetadata shard = dao.getShard(shardUuid); + assertNotNull(shard); + assertEquals(shard.getTableId(), tableId); + assertEquals(shard.getShardId(), shardId); + assertEquals(shard.getShardUuid(), shardUuid); + assertEquals(shard.getRowCount(), 13); + assertEquals(shard.getCompressedSize(), 42); + assertEquals(shard.getUncompressedSize(), 84); + assertEquals(shard.getXxhash64(), OptionalLong.of(1234)); + } @Test public void testInsertShardNodeUsingShardUuid() @@ -138,7 +137,7 @@ public void testInsertShardNodeUsingShardUuid() long tableId = createTable("test"); UUID shard = UUID.randomUUID(); - dao.insertShard(shard, tableId, null, 0, 0, 0); + dao.insertShard(shard, tableId, null, 0, 0, 0, 0); dao.insertShardNode(shard, nodeId); @@ -177,19 +176,19 @@ public void testNodeShards() long plainTableId = metadataDao.insertTable("test", "plain", false, false, null, 0); long bucketedTableId = metadataDao.insertTable("test", "bucketed", false, false, distributionId, 0); - long shardId1 = dao.insertShard(shardUuid1, plainTableId, null, 1, 11, 111); - long shardId2 = dao.insertShard(shardUuid2, plainTableId, null, 2, 22, 222); - long shardId3 = dao.insertShard(shardUuid3, bucketedTableId, 8, 3, 33, 333); - long shardId4 = dao.insertShard(shardUuid4, bucketedTableId, 9, 4, 44, 444); - long shardId5 = dao.insertShard(shardUuid5, bucketedTableId, 7, 5, 55, 555); + long shardId1 = dao.insertShard(shardUuid1, plainTableId, null, 1, 11, 111, 888_111); + long shardId2 = dao.insertShard(shardUuid2, plainTableId, null, 2, 22, 222, 888_222); + long shardId3 = dao.insertShard(shardUuid3, bucketedTableId, 8, 3, 33, 333, 888_333); + long shardId4 = dao.insertShard(shardUuid4, bucketedTableId, 9, 4, 44, 444, 888_444); + long shardId5 = dao.insertShard(shardUuid5, bucketedTableId, 7, 5, 55, 555, 888_555); OptionalInt noBucket = OptionalInt.empty(); OptionalLong noRange = OptionalLong.empty(); - ShardMetadata shard1 = new ShardMetadata(plainTableId, shardId1, shardUuid1, noBucket, 1, 11, 111, noRange, noRange); - ShardMetadata shard2 = new ShardMetadata(plainTableId, shardId2, shardUuid2, noBucket, 2, 22, 222, noRange, noRange); - ShardMetadata shard3 = new ShardMetadata(bucketedTableId, shardId3, shardUuid3, OptionalInt.of(8), 3, 33, 333, noRange, noRange); - ShardMetadata shard4 = new ShardMetadata(bucketedTableId, shardId4, shardUuid4, OptionalInt.of(9), 4, 44, 444, noRange, noRange); - ShardMetadata shard5 = new ShardMetadata(bucketedTableId, shardId5, shardUuid5, OptionalInt.of(7), 5, 55, 555, noRange, noRange); + ShardMetadata shard1 = new ShardMetadata(plainTableId, shardId1, shardUuid1, noBucket, 1, 11, 111, OptionalLong.of(888_111), noRange, noRange); + ShardMetadata shard2 = new ShardMetadata(plainTableId, shardId2, shardUuid2, noBucket, 2, 22, 222, OptionalLong.of(888_222), noRange, noRange); + ShardMetadata shard3 = new ShardMetadata(bucketedTableId, shardId3, shardUuid3, OptionalInt.of(8), 3, 33, 333, OptionalLong.of(888_333), noRange, noRange); + ShardMetadata shard4 = new ShardMetadata(bucketedTableId, shardId4, shardUuid4, OptionalInt.of(9), 4, 44, 444, OptionalLong.of(888_444), noRange, noRange); + ShardMetadata shard5 = new ShardMetadata(bucketedTableId, shardId5, shardUuid5, OptionalInt.of(7), 5, 55, 555, OptionalLong.of(888_555), noRange, noRange); assertEquals(dao.getShards(plainTableId), ImmutableList.of(shardUuid1, shardUuid2)); assertEquals(dao.getShards(bucketedTableId), ImmutableList.of(shardUuid3, shardUuid4, shardUuid5)); @@ -249,10 +248,10 @@ public void testShardSelection() UUID shardUuid3 = UUID.randomUUID(); UUID shardUuid4 = UUID.randomUUID(); - long shardId1 = dao.insertShard(shardUuid1, tableId, null, 0, 0, 0); - long shardId2 = dao.insertShard(shardUuid2, tableId, null, 0, 0, 0); - long shardId3 = dao.insertShard(shardUuid3, tableId, null, 0, 0, 0); - long shardId4 = dao.insertShard(shardUuid4, tableId, null, 0, 0, 0); + long shardId1 = dao.insertShard(shardUuid1, tableId, null, 0, 0, 0, 0); + long shardId2 = dao.insertShard(shardUuid2, tableId, null, 0, 0, 0, 0); + long shardId3 = dao.insertShard(shardUuid3, tableId, null, 0, 0, 0, 0); + long shardId4 = dao.insertShard(shardUuid4, tableId, null, 0, 0, 0, 0); List shards = dao.getShards(tableId); assertEquals(shards.size(), 4); diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestingShardDao.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestingShardDao.java index d9df8000b0ba2..b94b2c7782bc4 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestingShardDao.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestingShardDao.java @@ -46,8 +46,8 @@ interface TestingShardDao @SqlQuery("SELECT node_identifier FROM nodes") Set getAllNodesInUse(); - @SqlUpdate("INSERT INTO shards (shard_uuid, table_id, bucket_number, create_time, row_count, compressed_size, uncompressed_size)\n" + - "VALUES (:shardUuid, :tableId, :bucketNumber, CURRENT_TIMESTAMP, :rowCount, :compressedSize, :uncompressedSize)") + @SqlUpdate("INSERT INTO shards (shard_uuid, table_id, bucket_number, create_time, row_count, compressed_size, uncompressed_size, xxhash64)\n" + + "VALUES (:shardUuid, :tableId, :bucketNumber, CURRENT_TIMESTAMP, :rowCount, :compressedSize, :uncompressedSize, :xxhash64)") @GetGeneratedKeys long insertShard( @Bind("shardUuid") UUID shardUuid, @@ -55,7 +55,8 @@ long insertShard( @Bind("bucketNumber") Integer bucketNumber, @Bind("rowCount") long rowCount, @Bind("compressedSize") long compressedSize, - @Bind("uncompressedSize") long uncompressedSize); + @Bind("uncompressedSize") long uncompressedSize, + @Bind("xxhash64") long xxhash64); @SqlUpdate("INSERT INTO shard_nodes (shard_id, node_id)\n" + "VALUES (:shardId, :nodeId)\n") diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/OrcTestingUtil.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/OrcTestingUtil.java index 85c9a697961dd..fe5e0b4ea0ee0 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/OrcTestingUtil.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/OrcTestingUtil.java @@ -50,7 +50,7 @@ public static OrcDataSource fileOrcDataSource(File file) public static OrcRecordReader createReader(OrcDataSource dataSource, List columnIds, List types) throws IOException { - OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); List columnNames = orcReader.getColumnNames(); assertEquals(columnNames.size(), columnIds.size()); @@ -69,7 +69,7 @@ public static OrcRecordReader createReader(OrcDataSource dataSource, List public static OrcRecordReader createReaderNoRows(OrcDataSource dataSource) throws IOException { - OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); + OrcReader orcReader = new OrcReader(dataSource, new OrcMetadataReader(), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE), new DataSize(1, Unit.MEGABYTE)); assertEquals(orcReader.getColumnNames().size(), 0); diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestFileStorageService.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestFileStorageService.java index b7f2057d3fbc8..aef84c7decca8 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestFileStorageService.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestFileStorageService.java @@ -70,8 +70,10 @@ public void testFilePaths() UUID uuid = UUID.fromString("701e1a79-74f7-4f56-b438-b41e8e7d019d"); File staging = new File(temporary, format("staging/%s.orc", uuid)); File storage = new File(temporary, format("storage/70/1e/%s.orc", uuid)); + File quarantine = new File(temporary, format("quarantine/%s.orc", uuid)); assertEquals(store.getStagingFile(uuid), staging); assertEquals(store.getStorageFile(uuid), storage); + assertEquals(store.getQuarantineFile(uuid), quarantine); } @Test @@ -80,9 +82,11 @@ public void testStop() { File staging = new File(temporary, "staging"); File storage = new File(temporary, "storage"); + File quarantine = new File(temporary, "quarantine"); assertDirectory(staging); assertDirectory(storage); + assertDirectory(quarantine); File file = store.getStagingFile(randomUUID()); store.createParents(file); @@ -95,6 +99,7 @@ public void testStop() assertFalse(file.exists()); assertFalse(staging.exists()); assertDirectory(storage); + assertDirectory(quarantine); } @Test diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java index c16590d5a396c..cf22ef820fe81 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcFileRewriter.java @@ -13,15 +13,21 @@ */ package com.facebook.presto.raptor.storage; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.orc.OrcDataSource; import com.facebook.presto.orc.OrcRecordReader; import com.facebook.presto.raptor.storage.OrcFileRewriter.OrcFileInfo; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.json.JsonCodec; @@ -80,9 +86,15 @@ public void tearDown() public void testRewrite() throws Exception { + TypeManager typeManager = new TypeRegistry(); + // associate typeManager with a function registry + new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); + ArrayType arrayType = new ArrayType(BIGINT); ArrayType arrayOfArrayType = new ArrayType(arrayType); - MapType mapType = new MapType(createVarcharType(5), BOOLEAN); + Type mapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(createVarcharType(5).getTypeSignature()), + TypeSignatureParameter.of(BOOLEAN.getTypeSignature()))); List columnIds = ImmutableList.of(3L, 7L, 9L, 10L, 11L); List columnTypes = ImmutableList.of(BIGINT, createVarcharType(20), arrayType, mapType, arrayOfArrayType); diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcStorageManager.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcStorageManager.java index 5bbb7082b524f..4329fab7f7e90 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcStorageManager.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestOrcStorageManager.java @@ -56,6 +56,7 @@ import java.io.File; import java.io.IOException; +import java.util.Arrays; import java.util.BitSet; import java.util.Collection; import java.util.List; @@ -68,6 +69,7 @@ import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.raptor.metadata.SchemaDaoUtil.createTablesWithRetry; import static com.facebook.presto.raptor.metadata.TestDatabaseShardManager.createShardManager; +import static com.facebook.presto.raptor.storage.OrcStorageManager.xxhash64; import static com.facebook.presto.raptor.storage.OrcTestingUtil.createReader; import static com.facebook.presto.raptor.storage.OrcTestingUtil.octets; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -96,6 +98,7 @@ import static org.joda.time.DateTimeZone.UTC; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; import static org.testng.FileAssert.assertDirectory; @@ -199,6 +202,7 @@ public void testWriter() assertEquals(shardInfo.getRowCount(), 2); assertEquals(shardInfo.getCompressedSize(), file.length()); + assertEquals(shardInfo.getXxhash64(), xxhash64(file)); // verify primary and backup shard exist assertFile(file, "primary shard"); @@ -211,7 +215,7 @@ public void testWriter() assertTrue(file.getParentFile().delete()); assertFalse(file.exists()); - recoveryManager.restoreFromBackup(shardUuid, OptionalLong.empty()); + recoveryManager.restoreFromBackup(shardUuid, shardInfo.getCompressedSize(), OptionalLong.of(shardInfo.getXxhash64())); try (OrcDataSource dataSource = manager.openShard(shardUuid, READER_ATTRIBUTES)) { OrcRecordReader reader = createReader(dataSource, columnIds, columnTypes); @@ -377,13 +381,18 @@ public void testWriterRollback() // verify shard exists in staging String[] files = staging.list(); - assertEquals(files.length, 1); - assertTrue(files[0].endsWith(".orc")); + assertNotNull(files); + String stagingFile = Arrays.stream(files) + .filter(file -> file.endsWith(".orc")) + .findFirst() + .orElseThrow(() -> new AssertionError("file not found in staging")); // rollback should cleanup staging files sink.rollback(); - assertEquals(staging.list(), new String[] {}); + files = staging.list(); + assertNotNull(files); + assertTrue(Arrays.stream(files).noneMatch(stagingFile::equals)); } @Test @@ -601,7 +610,7 @@ public static OrcStorageManager createOrcStorageManager( storageService, backupStore, READER_ATTRIBUTES, - new BackupManager(backupStore, 1), + new BackupManager(backupStore, storageService, 1), recoveryManager, shardRecorder, new TypeRegistry(), diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardEjector.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardEjector.java index 452aa7bf482ca..44e88716bab02 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardEjector.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardEjector.java @@ -182,7 +182,7 @@ private static Set uuids(Set metadata) private static ShardInfo shardInfo(String node, long size) { - return new ShardInfo(randomUUID(), OptionalInt.empty(), ImmutableSet.of(node), ImmutableList.of(), 1, size, size * 2); + return new ShardInfo(randomUUID(), OptionalInt.empty(), ImmutableSet.of(node), ImmutableList.of(), 1, size, size * 2, 0); } private static NodeManager createNodeManager(String current, String... others) diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardRecovery.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardRecovery.java index 797b18f63aada..34ba2c7481b3a 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardRecovery.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardRecovery.java @@ -18,6 +18,7 @@ import com.facebook.presto.raptor.metadata.ShardManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.testing.TestingNodeManager; +import com.google.common.collect.ImmutableList; import com.google.common.io.Files; import io.airlift.units.Duration; import org.skife.jdbi.v2.DBI; @@ -28,12 +29,16 @@ import org.testng.annotations.Test; import java.io.File; +import java.util.List; import java.util.Optional; import java.util.OptionalLong; import java.util.UUID; +import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_BACKUP_CORRUPTION; import static com.facebook.presto.raptor.metadata.SchemaDaoUtil.createTablesWithRetry; import static com.facebook.presto.raptor.metadata.TestDatabaseShardManager.createShardManager; +import static com.facebook.presto.raptor.storage.OrcStorageManager.xxhash64; +import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.io.Files.createTempDir; import static io.airlift.testing.FileUtils.deleteRecursively; import static java.io.File.createTempFile; @@ -42,7 +47,9 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; @Test(singleThreaded = true) public class TestShardRecovery @@ -100,45 +107,141 @@ public void testShardRecovery() assertEquals(backupFile.length(), tempFile.length()); assertFalse(file.exists()); - recoveryManager.restoreFromBackup(shardUuid, OptionalLong.empty()); + recoveryManager.restoreFromBackup(shardUuid, tempFile.length(), OptionalLong.empty()); assertTrue(file.exists()); assertEquals(file.length(), tempFile.length()); } - @SuppressWarnings("EmptyTryBlock") @Test - public void testShardRecoveryExistingFileMismatch() + public void testShardRecoveryExistingFileSizeMismatch() throws Exception { UUID shardUuid = UUID.randomUUID(); - File file = storageService.getStorageFile(shardUuid); - storageService.createParents(file); + + // write data and backup File tempFile = createTempFile("tmp", null, temporary); + Files.write("test data", tempFile, UTF_8); + + backupStore.backupShard(shardUuid, tempFile); + assertTrue(backupStore.shardExists(shardUuid)); + + File backupFile = backupStore.getBackupFile(shardUuid); + assertTrue(Files.equal(tempFile, backupFile)); + + // write corrupt storage file with wrong length + File storageFile = storageService.getStorageFile(shardUuid); + storageService.createParents(storageFile); + + Files.write("bad data", storageFile, UTF_8); + + assertTrue(storageFile.exists()); + assertNotEquals(storageFile.length(), tempFile.length()); + assertFalse(Files.equal(storageFile, tempFile)); + + // restore from backup and verify + recoveryManager.restoreFromBackup(shardUuid, tempFile.length(), OptionalLong.empty()); + assertTrue(storageFile.exists()); + assertTrue(Files.equal(storageFile, tempFile)); + + // verify quarantine exists + List quarantined = listFiles(storageService.getQuarantineFile(shardUuid).getParentFile()); + assertEquals(quarantined.size(), 1); + assertTrue(getOnlyElement(quarantined).startsWith(shardUuid + ".orc.corrupt")); + } + + @Test + public void testShardRecoveryExistingFileChecksumMismatch() + throws Exception + { + UUID shardUuid = UUID.randomUUID(); + + // write data and backup + File tempFile = createTempFile("tmp", null, temporary); Files.write("test data", tempFile, UTF_8); - Files.write("bad data", file, UTF_8); backupStore.backupShard(shardUuid, tempFile); + assertTrue(backupStore.shardExists(shardUuid)); + + File backupFile = backupStore.getBackupFile(shardUuid); + assertTrue(Files.equal(tempFile, backupFile)); + + // write corrupt storage file with wrong data + File storageFile = storageService.getStorageFile(shardUuid); + storageService.createParents(storageFile); - long backupSize = tempFile.length(); + Files.write("test xata", storageFile, UTF_8); + + assertTrue(storageFile.exists()); + assertEquals(storageFile.length(), tempFile.length()); + assertFalse(Files.equal(storageFile, tempFile)); + + // restore from backup and verify + recoveryManager.restoreFromBackup(shardUuid, tempFile.length(), OptionalLong.of(xxhash64(tempFile))); + + assertTrue(storageFile.exists()); + assertTrue(Files.equal(storageFile, tempFile)); + + // verify quarantine exists + List quarantined = listFiles(storageService.getQuarantineFile(shardUuid).getParentFile()); + assertEquals(quarantined.size(), 1); + assertTrue(getOnlyElement(quarantined).startsWith(shardUuid + ".orc.corrupt")); + } + + @Test + public void testShardRecoveryBackupChecksumMismatch() + throws Exception + { + UUID shardUuid = UUID.randomUUID(); + + // write storage file + File storageFile = storageService.getStorageFile(shardUuid); + storageService.createParents(storageFile); + + Files.write("test data", storageFile, UTF_8); + + long size = storageFile.length(); + long xxhash64 = xxhash64(storageFile); + + // backup and verify + backupStore.backupShard(shardUuid, storageFile); assertTrue(backupStore.shardExists(shardUuid)); - assertEquals(backupStore.getBackupFile(shardUuid).length(), backupSize); + File backupFile = backupStore.getBackupFile(shardUuid); + assertTrue(Files.equal(storageFile, backupFile)); - assertTrue(file.exists()); - assertNotEquals(file.length(), backupSize); + // corrupt backup file + Files.write("test xata", backupFile, UTF_8); - recoveryManager.restoreFromBackup(shardUuid, OptionalLong.of(backupSize)); + assertTrue(backupFile.exists()); + assertEquals(storageFile.length(), backupFile.length()); + assertFalse(Files.equal(storageFile, backupFile)); - assertTrue(file.exists()); - assertEquals(file.length(), backupSize); + // delete local file to force restore + assertTrue(storageFile.delete()); + assertFalse(storageFile.exists()); + + // restore should fail + try { + recoveryManager.restoreFromBackup(shardUuid, size, OptionalLong.of(xxhash64)); + fail("expected exception"); + } + catch (PrestoException e) { + assertEquals(e.getErrorCode(), RAPTOR_BACKUP_CORRUPTION.toErrorCode()); + assertEquals(e.getMessage(), "Backup is corrupt after read: " + shardUuid); + } + + // verify quarantine exists + List quarantined = listFiles(storageService.getQuarantineFile(shardUuid).getParentFile()); + assertEquals(quarantined.size(), 1); + assertTrue(getOnlyElement(quarantined).startsWith(shardUuid + ".orc.corrupt")); } @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "No backup file found for shard: .*") public void testNoBackupException() throws Exception { - recoveryManager.restoreFromBackup(UUID.randomUUID(), OptionalLong.empty()); + recoveryManager.restoreFromBackup(UUID.randomUUID(), 0, OptionalLong.empty()); } public static ShardRecoveryManager createShardRecoveryManager( @@ -154,4 +257,11 @@ public static ShardRecoveryManager createShardRecoveryManager( new Duration(5, MINUTES), 10); } + + private static List listFiles(File path) + { + String[] files = path.list(); + assertNotNull(files); + return ImmutableList.copyOf(files); + } } diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java index ba377fac87aba..45a257521232b 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/storage/TestShardWriter.java @@ -14,14 +14,20 @@ package com.facebook.presto.raptor.storage; import com.facebook.presto.RowPagesBuilder; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.orc.OrcDataSource; import com.facebook.presto.orc.OrcRecordReader; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.json.JsonCodec; @@ -77,10 +83,16 @@ public void tearDown() public void testWriter() throws Exception { + TypeManager typeManager = new TypeRegistry(); + // associate typeManager with a function registry + new FunctionRegistry(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig()); + List columnIds = ImmutableList.of(1L, 2L, 4L, 6L, 7L, 8L, 9L, 10L); ArrayType arrayType = new ArrayType(BIGINT); ArrayType arrayOfArrayType = new ArrayType(arrayType); - MapType mapType = new MapType(createVarcharType(10), BOOLEAN); + Type mapType = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(createVarcharType(10).getTypeSignature()), + TypeSignatureParameter.of(BOOLEAN.getTypeSignature()))); List columnTypes = ImmutableList.of(BIGINT, createVarcharType(10), VARBINARY, DOUBLE, BOOLEAN, arrayType, mapType, arrayOfArrayType); File file = new File(directory, System.nanoTime() + ".orc"); @@ -151,7 +163,9 @@ public void testWriter() assertEquals(column6.getPositionCount(), 3); assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column6, 0), mapBlockOf(createVarcharType(5), BOOLEAN, "k1", true))); - assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column6, 1), mapBlockOf(createVarcharType(5), BOOLEAN, "k2", null))); + Block object = arrayType.getObject(column6, 1); + Block k2 = mapBlockOf(createVarcharType(5), BOOLEAN, "k2", null); + assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, object, k2)); assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column6, 2), mapBlockOf(createVarcharType(5), BOOLEAN, "k3", false))); Block column7 = reader.readBlock(arrayOfArrayType, 7); diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/systemtables/TestShardMetadataRecordCursor.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/systemtables/TestShardMetadataRecordCursor.java index eb5aa9d217c81..82876a25f9268 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/systemtables/TestShardMetadataRecordCursor.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/systemtables/TestShardMetadataRecordCursor.java @@ -104,9 +104,9 @@ public void testSimple() UUID uuid1 = UUID.randomUUID(); UUID uuid2 = UUID.randomUUID(); UUID uuid3 = UUID.randomUUID(); - ShardInfo shardInfo1 = new ShardInfo(uuid1, bucketNumber, ImmutableSet.of("node1"), ImmutableList.of(), 1, 10, 100); - ShardInfo shardInfo2 = new ShardInfo(uuid2, bucketNumber, ImmutableSet.of("node2"), ImmutableList.of(), 2, 20, 200); - ShardInfo shardInfo3 = new ShardInfo(uuid3, bucketNumber, ImmutableSet.of("node3"), ImmutableList.of(), 3, 30, 300); + ShardInfo shardInfo1 = new ShardInfo(uuid1, bucketNumber, ImmutableSet.of("node1"), ImmutableList.of(), 1, 10, 100, 0x1234); + ShardInfo shardInfo2 = new ShardInfo(uuid2, bucketNumber, ImmutableSet.of("node2"), ImmutableList.of(), 2, 20, 200, 0xCAFEBABEDEADBEEFL); + ShardInfo shardInfo3 = new ShardInfo(uuid3, bucketNumber, ImmutableSet.of("node3"), ImmutableList.of(), 3, 30, 300, 0xFEDCBA0987654321L); List shards = ImmutableList.of(shardInfo1, shardInfo2, shardInfo3); long transactionId = shardManager.beginTransaction(); @@ -130,8 +130,8 @@ public void testSimple() ImmutableMap.builder() .put(0, Domain.singleValue(createVarcharType(10), schema)) .put(1, Domain.create(ValueSet.ofRanges(lessThanOrEqual(createVarcharType(10), table)), true)) - .put(6, Domain.create(ValueSet.ofRanges(lessThanOrEqual(BIGINT, date1.getMillis()), greaterThan(BIGINT, date2.getMillis())), true)) - .put(7, Domain.create(ValueSet.ofRanges(lessThanOrEqual(BIGINT, date1.getMillis()), greaterThan(BIGINT, date2.getMillis())), true)) + .put(8, Domain.create(ValueSet.ofRanges(lessThanOrEqual(BIGINT, date1.getMillis()), greaterThan(BIGINT, date2.getMillis())), true)) + .put(9, Domain.create(ValueSet.ofRanges(lessThanOrEqual(BIGINT, date1.getMillis()), greaterThan(BIGINT, date2.getMillis())), true)) .build()); List actual; @@ -141,9 +141,9 @@ public void testSimple() assertEquals(actual.size(), 3); List expected = ImmutableList.of( - new MaterializedRow(DEFAULT_PRECISION, schema, table, utf8Slice(uuid1.toString()), null, 100L, 10L, 1L, null, null), - new MaterializedRow(DEFAULT_PRECISION, schema, table, utf8Slice(uuid2.toString()), null, 200L, 20L, 2L, null, null), - new MaterializedRow(DEFAULT_PRECISION, schema, table, utf8Slice(uuid3.toString()), null, 300L, 30L, 3L, null, null)); + new MaterializedRow(DEFAULT_PRECISION, schema, table, utf8Slice(uuid1.toString()), null, 100L, 10L, 1L, utf8Slice("0000000000001234"), null, null), + new MaterializedRow(DEFAULT_PRECISION, schema, table, utf8Slice(uuid2.toString()), null, 200L, 20L, 2L, utf8Slice("cafebabedeadbeef"), null, null), + new MaterializedRow(DEFAULT_PRECISION, schema, table, utf8Slice(uuid3.toString()), null, 300L, 30L, 3L, utf8Slice("fedcba0987654321"), null, null)); assertEquals(actual, expected); } diff --git a/presto-rcfile/pom.xml b/presto-rcfile/pom.xml index 2103d7c2fbf1a..a182d53b72a88 100644 --- a/presto-rcfile/pom.xml +++ b/presto-rcfile/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-rcfile diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileDecoderUtils.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileDecoderUtils.java index 055e510613f36..0def782f57737 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileDecoderUtils.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileDecoderUtils.java @@ -238,7 +238,12 @@ public static int calculateTruncationLength(Type type, Slice slice, int offset, return calculateTruncationLength(((VarcharType) type).getLength(), slice, offset, length); } if (type instanceof CharType) { - return calculateTruncationLength(((CharType) type).getLength(), slice, offset, length); + int truncationLength = calculateTruncationLength(((CharType) type).getLength(), slice, offset, length); + // At run-time char(x) is represented without trailing spaces + while (truncationLength > 0 && slice.getByte(offset + truncationLength - 1) == ' ') { + truncationLength--; + } + return truncationLength; } return length; } diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java index 684f67a3ace64..20625063b236b 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/RcFileReader.java @@ -137,21 +137,19 @@ private RcFileReader( this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(readColumns)); - checkArgument(offset >= 0, "offset is negative"); - checkArgument(offset < dataSource.getSize(), "offset is greater than data size"); - checkArgument(length >= 1, "length must be at least 1"); + verify(offset >= 0, "offset is negative"); + verify(offset < dataSource.getSize(), "offset is greater than data size"); + verify(length >= 1, "length must be at least 1"); this.length = length; this.end = offset + length; - checkArgument(end <= dataSource.getSize(), "offset plus length is greater than data size"); + verify(end <= dataSource.getSize(), "offset plus length is greater than data size"); // read header Slice magic = input.readSlice(RCFILE_MAGIC.length()); boolean compressed; if (RCFILE_MAGIC.equals(magic)) { version = input.readByte(); - if (version > CURRENT_VERSION) { - throw corrupt("RCFile version %s not supported: %s", version, dataSource); - } + verify(version <= CURRENT_VERSION, "RCFile version %s not supported: %s", version, dataSource); validateWrite(validation -> validation.getVersion() == version, "Unexpected file version"); compressed = input.readBoolean(); } @@ -160,19 +158,14 @@ else if (SEQUENCE_FILE_MAGIC.equals(magic)) { // first version of RCFile used magic SEQ with version 6 byte sequenceFileVersion = input.readByte(); - if (sequenceFileVersion != SEQUENCE_FILE_VERSION) { - throw corrupt("File %s is a SequenceFile not an RCFile", dataSource); - } + verify(sequenceFileVersion == SEQUENCE_FILE_VERSION, "File %s is a SequenceFile not an RCFile", dataSource); // this is the first version of RCFile this.version = FIRST_VERSION; Slice keyClassName = readLengthPrefixedString(input); Slice valueClassName = readLengthPrefixedString(input); - if (!RCFILE_KEY_BUFFER_NAME.equals(keyClassName) || !RCFILE_VALUE_BUFFER_NAME.equals(valueClassName)) { - throw corrupt("File %s is a SequenceFile not an RCFile", dataSource); - } - + verify(RCFILE_KEY_BUFFER_NAME.equals(keyClassName) && RCFILE_VALUE_BUFFER_NAME.equals(valueClassName), "File %s is a SequenceFile not an RCFile", dataSource); compressed = input.readBoolean(); // RC file is never block compressed @@ -197,12 +190,8 @@ else if (SEQUENCE_FILE_MAGIC.equals(magic)) { // read metadata int metadataEntries = Integer.reverseBytes(input.readInt()); - if (metadataEntries < 0) { - throw corrupt("Invalid metadata entry count %s in RCFile %s", metadataEntries, dataSource); - } - if (metadataEntries > MAX_METADATA_ENTRIES) { - throw corrupt("Too many metadata entries (%s) in RCFile %s", metadataEntries, dataSource); - } + verify(metadataEntries >= 0, "Invalid metadata entry count %s in RCFile %s", metadataEntries, dataSource); + verify(metadataEntries <= MAX_METADATA_ENTRIES, "Too many metadata entries (%s) in RCFile %s", metadataEntries, dataSource); ImmutableMap.Builder metadataBuilder = ImmutableMap.builder(); for (int i = 0; i < metadataEntries; i++) { metadataBuilder.put(readLengthPrefixedString(input).toStringUtf8(), readLengthPrefixedString(input).toStringUtf8()); @@ -220,9 +209,7 @@ else if (SEQUENCE_FILE_MAGIC.equals(magic)) { } // initialize columns - if (columnCount > MAX_COLUMN_COUNT) { - throw corrupt("Too many columns (%s) in RCFile %s", columnCountString, dataSource); - } + verify(columnCount <= MAX_COLUMN_COUNT, "Too many columns (%s) in RCFile %s", columnCountString, dataSource); columns = new Column[columnCount]; for (Entry entry : readColumns.entrySet()) { if (entry.getKey() < columnCount) { @@ -339,16 +326,12 @@ public int advance() } // read uncompressed size of row group (which is useless information) - if (input.remaining() < SIZE_OF_INT) { - throw corrupt("RCFile truncated %s", dataSource); - } + verify(input.remaining() >= SIZE_OF_INT, "RCFile truncated %s", dataSource); int unusedRowGroupSize = Integer.reverseBytes(input.readInt()); // read sequence sync if present if (unusedRowGroupSize == -1) { - if (input.remaining() < SIZE_OF_LONG + SIZE_OF_LONG + SIZE_OF_INT) { - throw corrupt("RCFile truncated %s", dataSource); - } + verify(input.remaining() >= SIZE_OF_LONG + SIZE_OF_LONG + SIZE_OF_INT, "RCFile truncated %s", dataSource); // The full sync sequence is "0xFFFFFFFF syncFirst syncSecond". If // this sequence begins in our segment, we must continue process until the @@ -361,9 +344,7 @@ public int advance() return -1; } - if (syncFirst != input.readLong() || syncSecond != input.readLong()) { - throw corrupt("Invalid sync in RCFile %s", dataSource); - } + verify(syncFirst == input.readLong() && syncSecond == input.readLong(), "Invalid sync in RCFile %s", dataSource); // read the useless uncompressed length unusedRowGroupSize = Integer.reverseBytes(input.readInt()); @@ -371,9 +352,7 @@ public int advance() else if (rowsRead > 0) { validateWrite(writeValidation -> false, "Expected sync sequence for every row group except the first one"); } - if (unusedRowGroupSize <= 0) { - throw corrupt("Invalid uncompressed row group length %s", unusedRowGroupSize); - } + verify(unusedRowGroupSize > 0, "Invalid uncompressed row group length %s", unusedRowGroupSize); // read row group header int uncompressedHeaderSize = Integer.reverseBytes(input.readInt()); @@ -396,9 +375,7 @@ else if (rowsRead > 0) { header = buffer; } else { - if (compressedHeaderSize != uncompressedHeaderSize) { - throw corrupt("Invalid RCFile %s", dataSource); - } + verify(compressedHeaderSize == uncompressedHeaderSize, "Invalid RCFile %s", dataSource); header = compressedHeaderBuffer; } BasicSliceInput headerInput = header.getInput(); @@ -433,9 +410,7 @@ else if (rowsRead > 0) { } // this value is not used but validate it is correct since it might signal corruption - if (unusedRowGroupSize != totalCompressedDataSize + uncompressedHeaderSize) { - throw corrupt("Invalid row group size"); - } + verify(unusedRowGroupSize == totalCompressedDataSize + uncompressedHeaderSize, "Invalid row group size"); validateWriteRowGroupChecksum(); validateWritePageChecksum(); @@ -481,13 +456,18 @@ private Slice readLengthPrefixedString(SliceInput in) throws RcFileCorruptionException { int length = toIntExact(readVInt(in)); - if (length > MAX_METADATA_STRING_LENGTH) { - throw corrupt("Metadata string value is too long (%s) in RCFile %s", length, in); - } - + verify(length <= MAX_METADATA_STRING_LENGTH, "Metadata string value is too long (%s) in RCFile %s", length, in); return in.readSlice(length); } + private void verify(boolean expression, String messageFormat, Object... args) + throws RcFileCorruptionException + { + if (!expression) { + throw corrupt(messageFormat, args); + } + } + private RcFileCorruptionException corrupt(String messageFormat, Object... args) { closeQuietly(); diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/DateEncoding.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/DateEncoding.java index 2cf7025554b07..a190f4ffa8024 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/DateEncoding.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/DateEncoding.java @@ -19,13 +19,11 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISODateTimeFormat; -import java.io.IOException; import java.util.concurrent.TimeUnit; import static java.lang.Math.toIntExact; @@ -68,17 +66,12 @@ public void encodeValueInto(int depth, Block block, int position, SliceOutput ou private void encodeValue(Block block, int position, SliceOutput output) { - try { - long days = type.getLong(block, position); - long millis = TimeUnit.DAYS.toMillis(days); - buffer.setLength(0); - HIVE_DATE_PARSER.printTo(buffer, millis); - for (int index = 0; index < buffer.length(); index++) { - output.writeByte(buffer.charAt(index)); - } - } - catch (IOException e) { - throw Throwables.propagate(e); + long days = type.getLong(block, position); + long millis = TimeUnit.DAYS.toMillis(days); + buffer.setLength(0); + HIVE_DATE_PARSER.printTo(buffer, millis); + for (int index = 0; index < buffer.length(); index++) { + output.writeByte(buffer.charAt(index)); } } diff --git a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/TimestampEncoding.java b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/TimestampEncoding.java index 57b7afab548ff..c1bf5eba2129f 100644 --- a/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/TimestampEncoding.java +++ b/presto-rcfile/src/main/java/com/facebook/presto/rcfile/text/TimestampEncoding.java @@ -19,7 +19,6 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; -import com.google.common.base.Throwables; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import org.joda.time.DateTimeZone; @@ -29,8 +28,6 @@ import org.joda.time.format.DateTimeParser; import org.joda.time.format.DateTimePrinter; -import java.io.IOException; - public class TimestampEncoding implements TextColumnEncoding { @@ -66,40 +63,30 @@ public TimestampEncoding(Type type, Slice nullSequence, DateTimeZone hiveStorage @Override public void encodeColumn(Block block, SliceOutput output, EncodeOutput encodeOutput) { - try { - for (int position = 0; position < block.getPositionCount(); position++) { - if (block.isNull(position)) { - output.writeBytes(nullSequence); - } - else { - long millis = type.getLong(block, position); - buffer.setLength(0); - dateTimeFormatter.printTo(buffer, millis); - for (int index = 0; index < buffer.length(); index++) { - output.writeByte(buffer.charAt(index)); - } + for (int position = 0; position < block.getPositionCount(); position++) { + if (block.isNull(position)) { + output.writeBytes(nullSequence); + } + else { + long millis = type.getLong(block, position); + buffer.setLength(0); + dateTimeFormatter.printTo(buffer, millis); + for (int index = 0; index < buffer.length(); index++) { + output.writeByte(buffer.charAt(index)); } - encodeOutput.closeEntry(); } - } - catch (IOException e) { - throw Throwables.propagate(e); + encodeOutput.closeEntry(); } } @Override public void encodeValueInto(int depth, Block block, int position, SliceOutput output) { - try { - long millis = type.getLong(block, position); - buffer.setLength(0); - dateTimeFormatter.printTo(buffer, millis); - for (int index = 0; index < buffer.length(); index++) { - output.writeByte(buffer.charAt(index)); - } - } - catch (IOException e) { - throw Throwables.propagate(e); + long millis = type.getLong(block, position); + buffer.setLength(0); + dateTimeFormatter.printTo(buffer, millis); + for (int index = 0; index < buffer.length(); index++) { + output.writeByte(buffer.charAt(index)); } } diff --git a/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java b/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java index 2f0bb03d964b7..9a01b72a18515 100644 --- a/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java +++ b/presto-rcfile/src/test/java/com/facebook/presto/rcfile/RcFileTester.java @@ -13,24 +13,31 @@ */ package com.facebook.presto.rcfile; +import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.hadoop.HadoopNative; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.rcfile.binary.BinaryRcFileEncoding; import com.facebook.presto.rcfile.text.TextRcFileEncoding; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.SqlVarbinary; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.spi.type.VarcharType; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; -import com.facebook.presto.type.RowType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.base.Throwables; import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; @@ -179,7 +186,11 @@ @SuppressWarnings("StaticPseudoFunctionalStyleMethod") public class RcFileTester { + private static final TypeManager TYPE_MANAGER = new TypeRegistry(); static { + // associate TYPE_MANAGER with a function registry + new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig()); + HadoopNative.requireHadoopNative(); } @@ -1161,7 +1172,9 @@ private static Object toHiveStruct(Object input) private static MapType createMapType(Type type) { - return new MapType(type, type); + return (MapType) TYPE_MANAGER.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(type.getTypeSignature()), + TypeSignatureParameter.of(type.getTypeSignature()))); } private static Object toHiveMap(Object nullKeyValue, Object input) diff --git a/presto-record-decoder/pom.xml b/presto-record-decoder/pom.xml index e7a37032e744b..c4998ff907f83 100644 --- a/presto-record-decoder/pom.xml +++ b/presto-record-decoder/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-record-decoder diff --git a/presto-redis/pom.xml b/presto-redis/pom.xml index 38e5d1eaa59f2..66940746c2481 100644 --- a/presto-redis/pom.xml +++ b/presto-redis/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-redis diff --git a/presto-resource-group-managers/pom.xml b/presto-resource-group-managers/pom.xml index ddd0c97a74063..a00b16abb83ae 100644 --- a/presto-resource-group-managers/pom.xml +++ b/presto-resource-group-managers/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-resource-group-managers diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java index cb8aaceadf38a..df9093323eab6 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/AbstractResourceConfigurationManager.java @@ -48,7 +48,7 @@ public abstract class AbstractResourceConfigurationManager @GuardedBy("generalPoolMemoryFraction") private long generalPoolBytes; - protected abstract Optional getCpuQuotaPeriodMillis(); + protected abstract Optional getCpuQuotaPeriod(); protected abstract List getRootGroups(); protected void validateRootGroups(ManagerSpec managerSpec) @@ -177,18 +177,25 @@ protected void configureGroup(ResourceGroup group, ResourceGroupSpec match) } group.setMaxQueuedQueries(match.getMaxQueued()); group.setMaxRunningQueries(match.getMaxRunning()); + if (match.getQueuedTimeLimit().isPresent()) { + group.setQueuedTimeLimit(match.getQueuedTimeLimit().get()); + } + if (match.getRunningTimeLimit().isPresent()) { + group.setRunningTimeLimit(match.getRunningTimeLimit().get()); + } if (match.getSchedulingPolicy().isPresent()) { group.setSchedulingPolicy(match.getSchedulingPolicy().get()); } if (match.getSchedulingWeight().isPresent()) { group.setSchedulingWeight(match.getSchedulingWeight().get()); } - if (match.getJmxExport().isPresent()) { + // if the new and current values do not differ an exception is thrown + if (match.getJmxExport().isPresent() && match.getJmxExport().get() != group.getJmxExport()) { group.setJmxExport(match.getJmxExport().get()); } if (match.getSoftCpuLimit().isPresent() || match.getHardCpuLimit().isPresent()) { // This will never throw an exception if the validateManagerSpec method succeeds - checkState(getCpuQuotaPeriodMillis().isPresent(), "Must specify hard CPU limit in addition to soft limit"); + checkState(getCpuQuotaPeriod().isPresent(), "Must specify hard CPU limit in addition to soft limit"); Duration limit; if (match.getHardCpuLimit().isPresent()) { limit = match.getHardCpuLimit().get(); @@ -196,7 +203,7 @@ protected void configureGroup(ResourceGroup group, ResourceGroupSpec match) else { limit = match.getSoftCpuLimit().get(); } - long rate = (long) Math.min(1000.0 * limit.toMillis() / (double) getCpuQuotaPeriodMillis().get().toMillis(), Long.MAX_VALUE); + long rate = (long) Math.min(1000.0 * limit.toMillis() / (double) getCpuQuotaPeriod().get().toMillis(), Long.MAX_VALUE); rate = Math.max(1, rate); group.setCpuQuotaGenerationMillisPerSecond(rate); } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/FileResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/FileResourceGroupConfigurationManager.java index 152cda18fb99d..34d831ad42986 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/FileResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/FileResourceGroupConfigurationManager.java @@ -37,7 +37,7 @@ public class FileResourceGroupConfigurationManager { private final List rootGroups; private final List selectors; - private final Optional cpuQuotaPeriodMillis; + private final Optional cpuQuotaPeriod; @Inject public FileResourceGroupConfigurationManager(ClusterMemoryPoolManager memoryPoolManager, FileResourceGroupConfig config, JsonCodec codec) @@ -54,15 +54,15 @@ public FileResourceGroupConfigurationManager(ClusterMemoryPoolManager memoryPool throw Throwables.propagate(e); } this.rootGroups = managerSpec.getRootGroups(); - this.cpuQuotaPeriodMillis = managerSpec.getCpuQuotaPeriod(); + this.cpuQuotaPeriod = managerSpec.getCpuQuotaPeriod(); validateRootGroups(managerSpec); this.selectors = buildSelectors(managerSpec); } @Override - protected Optional getCpuQuotaPeriodMillis() + protected Optional getCpuQuotaPeriod() { - return cpuQuotaPeriodMillis; + return cpuQuotaPeriod; } @Override diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/ResourceGroupSpec.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/ResourceGroupSpec.java index 16152268b08dc..abc35daa46d72 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/ResourceGroupSpec.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/ResourceGroupSpec.java @@ -47,6 +47,8 @@ public class ResourceGroupSpec private final Optional jmxExport; private final Optional softCpuLimit; private final Optional hardCpuLimit; + private final Optional queuedTimeLimit; + private final Optional runningTimeLimit; @JsonCreator public ResourceGroupSpec( @@ -59,11 +61,15 @@ public ResourceGroupSpec( @JsonProperty("subGroups") Optional> subGroups, @JsonProperty("jmxExport") Optional jmxExport, @JsonProperty("softCpuLimit") Optional softCpuLimit, - @JsonProperty("hardCpuLimit") Optional hardCpuLimit) + @JsonProperty("hardCpuLimit") Optional hardCpuLimit, + @JsonProperty("queuedTimeLimit") Optional queuedTimeLimit, + @JsonProperty("runningTimeLimit") Optional runningTimeLimit) { this.softCpuLimit = requireNonNull(softCpuLimit, "softCpuLimit is null"); this.hardCpuLimit = requireNonNull(hardCpuLimit, "hardCpuLimit is null"); this.jmxExport = requireNonNull(jmxExport, "jmxExport is null"); + this.queuedTimeLimit = requireNonNull(queuedTimeLimit, "queuedTimeLimit is null"); + this.runningTimeLimit = requireNonNull(runningTimeLimit, "runningTimeLimit is null"); this.name = requireNonNull(name, "name is null"); checkArgument(maxQueued >= 0, "maxQueued is negative"); this.maxQueued = maxQueued; @@ -148,6 +154,16 @@ public Optional getHardCpuLimit() return hardCpuLimit; } + public Optional getQueuedTimeLimit() + { + return queuedTimeLimit; + } + + public Optional getRunningTimeLimit() + { + return runningTimeLimit; + } + @Override public boolean equals(Object other) { @@ -167,7 +183,9 @@ public boolean equals(Object other) subGroups.equals(that.subGroups) && jmxExport.equals(that.jmxExport) && softCpuLimit.equals(that.softCpuLimit) && - hardCpuLimit.equals(that.hardCpuLimit)); + hardCpuLimit.equals(that.hardCpuLimit) && + queuedTimeLimit.equals(that.queuedTimeLimit) && + runningTimeLimit.equals(that.runningTimeLimit)); } // Subgroups not included, used to determine whether a group needs to be reconfigured @@ -184,7 +202,9 @@ public boolean sameConfig(ResourceGroupSpec other) schedulingWeight.equals(other.schedulingWeight) && jmxExport.equals(other.jmxExport) && softCpuLimit.equals(other.softCpuLimit) && - hardCpuLimit.equals(other.hardCpuLimit)); + hardCpuLimit.equals(other.hardCpuLimit) && + queuedTimeLimit.equals(other.queuedTimeLimit) && + runningTimeLimit.equals(other.runningTimeLimit)); } @Override @@ -200,7 +220,9 @@ public int hashCode() subGroups, jmxExport, softCpuLimit, - hardCpuLimit); + hardCpuLimit, + queuedTimeLimit, + runningTimeLimit); } @Override @@ -216,6 +238,8 @@ public String toString() .add("jmxExport", jmxExport) .add("softCpuLimit", softCpuLimit) .add("hardCpuLimit", hardCpuLimit) + .add("queuedTimeLimit", queuedTimeLimit) + .add("runningTimeLimit", runningTimeLimit) .toString(); } } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java index d5899937dce32..438e687816b02 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/DbResourceGroupConfigurationManager.java @@ -82,7 +82,7 @@ public DbResourceGroupConfigurationManager(ClusterMemoryPoolManager memoryPoolMa } @Override - protected Optional getCpuQuotaPeriodMillis() + protected Optional getCpuQuotaPeriod() { return cpuQuotaPeriod.get(); } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/MysqlDaoProvider.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/MysqlDaoProvider.java index 994299a5ea292..4872c36294fbd 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/MysqlDaoProvider.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/MysqlDaoProvider.java @@ -14,7 +14,9 @@ package com.facebook.presto.resourceGroups.db; import com.mysql.jdbc.jdbc2.optional.MysqlDataSource; + import org.skife.jdbi.v2.DBI; +import org.skife.jdbi.v2.IDBI; import javax.inject.Inject; import javax.inject.Provider; @@ -32,7 +34,7 @@ public MysqlDaoProvider(DbResourceGroupConfig config) requireNonNull(config, "DbResourceGroupConfig is null"); MysqlDataSource dataSource = new MysqlDataSource(); dataSource.setURL(config.getConfigDbUrl()); - DBI dbi = new DBI(dataSource); + IDBI dbi = new DBI(dataSource); this.dao = dbi.onDemand(ResourceGroupsDao.class); } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupSpecBuilder.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupSpecBuilder.java index 121c478fade89..18324daba840a 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupSpecBuilder.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupSpecBuilder.java @@ -38,6 +38,8 @@ public class ResourceGroupSpecBuilder private final Optional jmxExport; private final Optional softCpuLimit; private final Optional hardCpuLimit; + private final Optional queuedTimeLimit; + private final Optional runningTimeLimit; private final Optional parentId; private final ImmutableList.Builder subGroups = ImmutableList.builder(); @@ -52,6 +54,8 @@ public class ResourceGroupSpecBuilder Optional jmxExport, Optional softCpuLimit, Optional hardCpuLimit, + Optional queuedTimeLimit, + Optional runningTimeLimit, Optional parentId ) { @@ -65,6 +69,8 @@ public class ResourceGroupSpecBuilder this.jmxExport = requireNonNull(jmxExport, "jmxExport is null"); this.softCpuLimit = requireNonNull(softCpuLimit, "softCpuLimit is null").map(Duration::valueOf); this.hardCpuLimit = requireNonNull(hardCpuLimit, "hardCpuLimit is null").map(Duration::valueOf); + this.queuedTimeLimit = requireNonNull(queuedTimeLimit, "queuedTimeLimit is null").map(Duration::valueOf); + this.runningTimeLimit = requireNonNull(runningTimeLimit, "runningTimeLimit is null").map(Duration::valueOf); this.parentId = parentId; } @@ -110,9 +116,9 @@ public ResourceGroupSpec build() Optional.of(subGroups.build()), jmxExport, softCpuLimit, - hardCpuLimit - - ); + hardCpuLimit, + queuedTimeLimit, + runningTimeLimit); } public static class Mapper @@ -142,6 +148,8 @@ public ResourceGroupSpecBuilder map(int index, ResultSet resultSet, StatementCon if (resultSet.wasNull()) { parentId = Optional.empty(); } + Optional queuedTimeLimit = Optional.ofNullable(resultSet.getString("queued_time_limit")); + Optional runningTimeLimit = Optional.ofNullable(resultSet.getString("running_time_limit")); return new ResourceGroupSpecBuilder( id, nameTemplate, @@ -153,6 +161,8 @@ public ResourceGroupSpecBuilder map(int index, ResultSet resultSet, StatementCon jmxExport, softCpuLimit, hardCpuLimit, + queuedTimeLimit, + runningTimeLimit, parentId ); } diff --git a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupsDao.java b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupsDao.java index c3103d2bac4db..09ce7c3d757dc 100644 --- a/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupsDao.java +++ b/presto-resource-group-managers/src/main/java/com/facebook/presto/resourceGroups/db/ResourceGroupsDao.java @@ -43,6 +43,8 @@ public interface ResourceGroupsDao " jmx_export BOOLEAN NULL,\n" + " soft_cpu_limit VARCHAR(128) NULL,\n" + " hard_cpu_limit VARCHAR(128) NULL,\n" + + " queued_time_limit VARCHAR(128) NULL,\n" + + " running_time_limit VARCHAR(128) NULL,\n" + " parent BIGINT NULL,\n" + " PRIMARY KEY (resource_group_id),\n" + " FOREIGN KEY (parent) REFERENCES resource_groups (resource_group_id)\n" + @@ -50,7 +52,8 @@ public interface ResourceGroupsDao void createResourceGroupsTable(); @SqlQuery("SELECT resource_group_id, name, soft_memory_limit, max_queued, max_running," + - " scheduling_policy, scheduling_weight, jmx_export, soft_cpu_limit, hard_cpu_limit, parent\n" + + " scheduling_policy, scheduling_weight, jmx_export, soft_cpu_limit, hard_cpu_limit, " + + " queued_time_limit, running_time_limit, parent\n" + "FROM resource_groups") @Mapper(ResourceGroupSpecBuilder.Mapper.class) List getResourceGroups(); diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java index b65857a186565..b06cf2ff10e43 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestFileResourceGroupConfigurationManager.java @@ -69,6 +69,8 @@ public void testConfiguration() assertEquals(global.getSchedulingPolicy(), WEIGHTED); assertEquals(global.getSchedulingWeight(), 0); assertEquals(global.getJmxExport(), true); + assertEquals(global.getQueuedTimeLimit(), new Duration(1, HOURS)); + assertEquals(global.getRunningTimeLimit(), new Duration(1, HOURS)); ResourceGroup sub = new TestingResourceGroup(new ResourceGroupId(new ResourceGroupId("global"), "sub")); manager.configure(sub, new SelectionContext(true, "user", Optional.empty(), 1)); @@ -78,6 +80,8 @@ public void testConfiguration() assertEquals(sub.getSchedulingPolicy(), null); assertEquals(sub.getSchedulingWeight(), 5); assertEquals(sub.getJmxExport(), false); + assertEquals(global.getQueuedTimeLimit(), new Duration(1, HOURS)); + assertEquals(global.getRunningTimeLimit(), new Duration(1, HOURS)); } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Selector refers to nonexistent group: a.b.c.X") diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestingResourceGroup.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestingResourceGroup.java index 122ded87c3dd0..0c886013c0908 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestingResourceGroup.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/TestingResourceGroup.java @@ -34,6 +34,8 @@ public class TestingResourceGroup private int schedulingWeight; private SchedulingPolicy policy; private boolean jmxExport; + private Duration queuedTimeLimit; + private Duration runningTimeLimit; public TestingResourceGroup(ResourceGroupId id) { @@ -153,4 +155,27 @@ public void setJmxExport(boolean export) { jmxExport = export; } + @Override + public Duration getQueuedTimeLimit() + { + return queuedTimeLimit; + } + + @Override + public void setQueuedTimeLimit(Duration queuedTimeLimit) + { + this.queuedTimeLimit = queuedTimeLimit; + } + + @Override + public Duration getRunningTimeLimit() + { + return runningTimeLimit; + } + + @Override + public void setRunningTimeLimit(Duration runningTimeLimit) + { + this.runningTimeLimit = runningTimeLimit; + } } diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/H2ResourceGroupsDao.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/H2ResourceGroupsDao.java index 1034a32e8060c..36c5d27cb7729 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/H2ResourceGroupsDao.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/H2ResourceGroupsDao.java @@ -29,8 +29,8 @@ void insertResourceGroupsGlobalProperties( void updateResourceGroupsGlobalProperties(@Bind("name") String name); @SqlUpdate("INSERT INTO resource_groups\n" + - "(resource_group_id, name, soft_memory_limit, max_queued, max_running, scheduling_policy, scheduling_weight, jmx_export, soft_cpu_limit, hard_cpu_limit, parent)\n" + - "VALUES (:resource_group_id, :name, :soft_memory_limit, :max_queued, :max_running, :scheduling_policy, :scheduling_weight, :jmx_export, :soft_cpu_limit, :hard_cpu_limit, :parent)") + "(resource_group_id, name, soft_memory_limit, max_queued, max_running, scheduling_policy, scheduling_weight, jmx_export, soft_cpu_limit, hard_cpu_limit, queued_time_limit, running_time_limit, parent)\n" + + "VALUES (:resource_group_id, :name, :soft_memory_limit, :max_queued, :max_running, :scheduling_policy, :scheduling_weight, :jmx_export, :soft_cpu_limit, :hard_cpu_limit, :queued_time_limit, :running_time_limit, :parent)") void insertResourceGroup( @Bind("resource_group_id") long resourceGroupId, @Bind("name") String name, @@ -42,6 +42,8 @@ void insertResourceGroup( @Bind("jmx_export") Boolean jmxExport, @Bind("soft_cpu_limit") String softCpuLimit, @Bind("hard_cpu_limit") String hardCpuLimit, + @Bind("queued_time_limit") String queuedTimeLimit, + @Bind("running_time_limit") String runningTimeLimit, @Bind("parent") Long parent ); @@ -56,6 +58,8 @@ void insertResourceGroup( ", jmx_export = :jmx_export\n" + ", soft_cpu_limit = :soft_cpu_limit\n" + ", hard_cpu_limit = :hard_cpu_limit\n" + + ", queued_time_limit = :queued_time_limit\n" + + ", running_time_limit = :running_time_limit\n" + ", parent = :parent\n" + "WHERE resource_group_id = :resource_group_id") void updateResourceGroup( @@ -69,6 +73,8 @@ void updateResourceGroup( @Bind("jmx_export") Boolean jmxExport, @Bind("soft_cpu_limit") String softCpuLimit, @Bind("hard_cpu_limit") String hardCpuLimit, + @Bind("queued_time_limit") String queuedTimeLimit, + @Bind("running_time_limit") String runningTimeLimit, @Bind("parent") Long parent); @SqlUpdate("DELETE FROM resource_groups WHERE resource_group_id = :resource_group_id") diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java index 8bc5492339c7c..316c48d067606 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestDbResourceGroupConfigurationManager.java @@ -52,19 +52,19 @@ public void testConfiguration() dao.createResourceGroupsTable(); dao.createSelectorsTable(); dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); - dao.insertResourceGroup(1, "global", "1MB", 1000, 100, "weighted", null, true, "1h", "1d", null); - dao.insertResourceGroup(2, "sub", "2MB", 4, 3, null, 5, null, null, null, 1L); + dao.insertResourceGroup(1, "global", "1MB", 1000, 100, "weighted", null, true, "1h", "1d", "1h", "1h", null); + dao.insertResourceGroup(2, "sub", "2MB", 4, 3, null, 5, null, null, null, "1h", "1h", 1L); dao.insertSelector(2, null, null); DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager((poolId, listener) -> { }, daoProvider.get()); AtomicBoolean exported = new AtomicBoolean(); InternalResourceGroup global = new InternalResourceGroup.RootInternalResourceGroup("global", (group, export) -> exported.set(export), directExecutor()); manager.configure(global, new SelectionContext(true, "user", Optional.empty(), 1)); - assertEqualsResourceGroup(global, "1MB", 1000, 100, WEIGHTED, DEFAULT_WEIGHT, true, new Duration(1, HOURS), new Duration(1, DAYS)); + assertEqualsResourceGroup(global, "1MB", 1000, 100, WEIGHTED, DEFAULT_WEIGHT, true, new Duration(1, HOURS), new Duration(1, DAYS), new Duration(1, HOURS), new Duration(1, HOURS)); exported.set(false); InternalResourceGroup sub = global.getOrCreateSubGroup("sub"); manager.configure(sub, new SelectionContext(true, "user", Optional.empty(), 1)); - assertEqualsResourceGroup(sub, "2MB", 4, 3, FAIR, 5, false, new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS)); + assertEqualsResourceGroup(sub, "2MB", 4, 3, FAIR, 5, false, new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(1, HOURS), new Duration(1, HOURS)); } @Test @@ -75,9 +75,9 @@ public void testDuplicates() dao.createResourceGroupsGlobalPropertiesTable(); dao.createResourceGroupsTable(); dao.createSelectorsTable(); - dao.insertResourceGroup(1, "global", "1MB", 1000, 100, null, null, null, null, null, null); + dao.insertResourceGroup(1, "global", "1MB", 1000, 100, null, null, null, null, null, null, null, null); try { - dao.insertResourceGroup(1, "global", "1MB", 1000, 100, null, null, null, null, null, null); + dao.insertResourceGroup(1, "global", "1MB", 1000, 100, null, null, null, null, null, null, null, null); fail("Expected to fail"); } catch (RuntimeException ex) { @@ -91,10 +91,10 @@ public void testDuplicates() dao.createResourceGroupsGlobalPropertiesTable(); dao.createResourceGroupsTable(); dao.createSelectorsTable(); - dao.insertResourceGroup(1, "global", "1MB", 1000, 100, null, null, null, null, null, null); - dao.insertResourceGroup(2, "sub", "1MB", 1000, 100, null, null, null, null, null, 1L); + dao.insertResourceGroup(1, "global", "1MB", 1000, 100, null, null, null, null, null, null, null, null); + dao.insertResourceGroup(2, "sub", "1MB", 1000, 100, null, null, null, null, null, null, null, 1L); try { - dao.insertResourceGroup(2, "sub", "1MB", 1000, 100, null, null, null, null, null, 1L); + dao.insertResourceGroup(2, "sub", "1MB", 1000, 100, null, null, null, null, null, null, null, 1L); } catch (RuntimeException ex) { assertTrue(ex instanceof UnableToExecuteStatementException); @@ -113,8 +113,8 @@ public void testMissing() dao.createResourceGroupsGlobalPropertiesTable(); dao.createResourceGroupsTable(); dao.createSelectorsTable(); - dao.insertResourceGroup(1, "global", "1MB", 1000, 100, "weighted", null, true, "1h", "1d", null); - dao.insertResourceGroup(2, "sub", "2MB", 4, 3, null, 5, null, null, null, 1L); + dao.insertResourceGroup(1, "global", "1MB", 1000, 100, "weighted", null, true, "1h", "1d", null, null, null); + dao.insertResourceGroup(2, "sub", "2MB", 4, 3, null, 5, null, null, null, null, null, 1L); dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); dao.insertSelector(2, null, null); DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager((poolId, listener) -> { @@ -133,8 +133,8 @@ public void testReconfig() dao.createResourceGroupsGlobalPropertiesTable(); dao.createResourceGroupsTable(); dao.createSelectorsTable(); - dao.insertResourceGroup(1, "global", "1MB", 1000, 100, "weighted", null, true, "1h", "1d", null); - dao.insertResourceGroup(2, "sub", "2MB", 4, 3, null, 5, null, null, null, 1L); + dao.insertResourceGroup(1, "global", "1MB", 1000, 100, "weighted", null, true, "1h", "1d", null, null, null); + dao.insertResourceGroup(2, "sub", "2MB", 4, 3, null, 5, null, null, null, null, null, 1L); dao.insertSelector(2, null, null); dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager( @@ -147,13 +147,13 @@ public void testReconfig() InternalResourceGroup globalSub = global.getOrCreateSubGroup("sub"); manager.configure(globalSub, new SelectionContext(true, "user", Optional.empty(), 1)); // Verify record exists - assertEqualsResourceGroup(globalSub, "2MB", 4, 3, FAIR, 5, false, new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS)); - dao.updateResourceGroup(2, "sub", "3MB", 2, 1, "weighted", 6, true, "1h", "1d", 1L); + assertEqualsResourceGroup(globalSub, "2MB", 4, 3, FAIR, 5, false, new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS)); + dao.updateResourceGroup(2, "sub", "3MB", 2, 1, "weighted", 6, true, "1h", "1d", null, null, 1L); do { MILLISECONDS.sleep(500); } while(globalSub.getJmxExport() == false); // Verify update - assertEqualsResourceGroup(globalSub, "3MB", 2, 1, WEIGHTED, 6, true, new Duration(1, HOURS), new Duration(1, DAYS)); + assertEqualsResourceGroup(globalSub, "3MB", 2, 1, WEIGHTED, 6, true, new Duration(1, HOURS), new Duration(1, DAYS), new Duration(Long.MAX_VALUE, MILLISECONDS), new Duration(Long.MAX_VALUE, MILLISECONDS)); // Verify delete dao.deleteSelectors(2); dao.deleteResourceGroup(2); @@ -171,7 +171,9 @@ private static void assertEqualsResourceGroup( int schedulingWeight, boolean jmxExport, Duration softCpuLimit, - Duration hardCpuLimit) + Duration hardCpuLimit, + Duration queuedTimeLimit, + Duration runningTimeLimit) { assertEquals(group.getSoftMemoryLimit(), DataSize.valueOf(softMemoryLimit)); assertEquals(group.getInfo().getMaxQueuedQueries(), maxQueued); @@ -181,5 +183,7 @@ private static void assertEqualsResourceGroup( assertEquals(group.getJmxExport(), jmxExport); assertEquals(group.getSoftCpuLimit(), softCpuLimit); assertEquals(group.getHardCpuLimit(), hardCpuLimit); + assertEquals(group.getQueuedTimeLimit(), queuedTimeLimit); + assertEquals(group.getRunningTimeLimit(), runningTimeLimit); } } diff --git a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestResourceGroupsDao.java b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestResourceGroupsDao.java index 906562a034d7b..cc93bf552d1a8 100644 --- a/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestResourceGroupsDao.java +++ b/presto-resource-group-managers/src/test/java/com/facebook/presto/resourceGroups/db/TestResourceGroupsDao.java @@ -49,19 +49,19 @@ public void testResourceGroups() private static void testResourceGroupInsert(H2ResourceGroupsDao dao, Map map) { - dao.insertResourceGroup(1, "global", "100%", 100, 100, null, null, null, null, null, null); - dao.insertResourceGroup(2, "bi", "50%", 50, 50, null, null, null, null, null, 1L); + dao.insertResourceGroup(1, "global", "100%", 100, 100, null, null, null, null, null, null, null, null); + dao.insertResourceGroup(2, "bi", "50%", 50, 50, null, null, null, null, null, null, null, 1L); List records = dao.getResourceGroups(); assertEquals(records.size(), 2); - map.put(1L, new ResourceGroupSpecBuilder(1, new ResourceGroupNameTemplate("global"), "100%", 100, 100, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), null)); - map.put(2L, new ResourceGroupSpecBuilder(2, new ResourceGroupNameTemplate("bi"), "50%", 50, 50, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(1L))); + map.put(1L, new ResourceGroupSpecBuilder(1, new ResourceGroupNameTemplate("global"), "100%", 100, 100, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), null)); + map.put(2L, new ResourceGroupSpecBuilder(2, new ResourceGroupNameTemplate("bi"), "50%", 50, 50, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(1L))); compareResourceGroups(map, records); } private static void testResourceGroupUpdate(H2ResourceGroupsDao dao, Map map) { - dao.updateResourceGroup(2, "bi", "40%", 40, 30, null, null, true, null, null, 1L); - ResourceGroupSpecBuilder updated = new ResourceGroupSpecBuilder(2, new ResourceGroupNameTemplate("bi"), "40%", 40, 30, Optional.empty(), Optional.empty(), Optional.of(true), Optional.empty(), Optional.empty(), Optional.of(1L)); + dao.updateResourceGroup(2, "bi", "40%", 40, 30, null, null, true, null, null, null, null, 1L); + ResourceGroupSpecBuilder updated = new ResourceGroupSpecBuilder(2, new ResourceGroupNameTemplate("bi"), "40%", 40, 30, Optional.empty(), Optional.empty(), Optional.of(true), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(1L)); map.put(2L, updated); compareResourceGroups(map, dao.getResourceGroups()); } @@ -101,9 +101,9 @@ private static void testSelectorInsert(H2ResourceGroupsDao dao, Map records = dao.getSelectors(); @@ -117,14 +117,12 @@ private static void testSelectorUpdate(H2ResourceGroupsDao dao, Map map) { - map.remove(2); SelectorRecord updated = new SelectorRecord(2, Optional.empty(), Optional.empty()); map.put(2L, updated); dao.updateSelector(2, null, null, "ping.*", "ping_source"); @@ -133,7 +131,6 @@ private static void testSelectorUpdateNull(H2ResourceGroupsDao dao, Map com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-server-rpm diff --git a/presto-server/pom.xml b/presto-server/pom.xml index 81911bd66a4df..d41304cfd8b31 100644 --- a/presto-server/pom.xml +++ b/presto-server/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-server diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml index 9bf92f721dbd7..6e11d70860aa2 100644 --- a/presto-server/src/main/provisio/presto.xml +++ b/presto-server/src/main/provisio/presto.xml @@ -145,4 +145,10 @@ + + + + + + diff --git a/presto-spi/pom.xml b/presto-spi/pom.xml index 22eed1c882fe1..05c113daaba3c 100644 --- a/presto-spi/pom.xml +++ b/presto-spi/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-spi @@ -27,6 +27,12 @@ units + + com.google.code.findbugs + annotations + provided + + com.fasterxml.jackson.core jackson-annotations @@ -50,6 +56,12 @@ test + + it.unimi.dsi + fastutil + test + + com.google.guava guava diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Page.java b/presto-spi/src/main/java/com/facebook/presto/spi/Page.java index 03267092c2b8b..e3fa7903ef057 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Page.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Page.java @@ -252,6 +252,17 @@ private static int determinePositionCount(Block... blocks) return blocks[0].getPositionCount(); } + public static Page mask(Page page, int[] retainedPositions) + { + requireNonNull(page, "page is null"); + requireNonNull(retainedPositions, "retainedPositions is null"); + + Block[] blocks = Arrays.stream(page.getBlocks()) + .map(block -> new DictionaryBlock(block, retainedPositions)) + .toArray(Block[]::new); + return new Page(retainedPositions.length, blocks); + } + private static class DictionaryBlockIndexes { private final List blocks = new ArrayList<>(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java index 7cad65c976f9a..b83686ffc1249 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayBlock.java @@ -90,7 +90,7 @@ public Block getRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { int positionCount = getPositionCount(); if (position < 0 || length < 0 || position + length > positionCount) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java index 03e825f1787f2..6a191d4945c7e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractFixedWidthBlock.java @@ -165,7 +165,7 @@ public boolean isNull(int position) } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { int positionCount = getPositionCount(); if (positionOffset < 0 || length < 0 || positionOffset + length > positionCount) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java index 0a3f465a50d70..83d30a2ea3ada 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractInterleavedBlock.java @@ -253,7 +253,7 @@ public Block copyRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { if (position == 0 && length == getPositionCount()) { // Calculation of getRegionSizeInBytes is expensive in this class. diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java new file mode 100644 index 0000000000000..d729eddf04eca --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractMapBlock.java @@ -0,0 +1,280 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.type.Type; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public abstract class AbstractMapBlock + implements Block +{ + // inverse of hash fill ratio, must be integer + static final int HASH_MULTIPLIER = 2; + + protected final Type keyType; + protected final MethodHandle keyNativeHashCode; + protected final MethodHandle keyBlockNativeEquals; + + public AbstractMapBlock(Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals) + { + this.keyType = requireNonNull(keyType, "keyType is null"); + // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block + this.keyNativeHashCode = keyNativeHashCode; + // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block + this.keyBlockNativeEquals = keyBlockNativeEquals; + } + + protected abstract Block getKeys(); + + protected abstract Block getValues(); + + protected abstract int[] getHashTables(); + + /** + * offset is entry-based, not position-based. In other words, + * if offset[1] is 6, it means the first map has 6 key-value pairs, + * not 6 key/values (which would be 3 pairs). + */ + protected abstract int[] getOffsets(); + + /** + * offset is entry-based, not position-based. (see getOffsets) + */ + protected abstract int getOffsetBase(); + + protected abstract boolean[] getMapIsNull(); + + private int getOffset(int position) + { + return getOffsets()[position + getOffsetBase()]; + } + + @Override + public BlockEncoding getEncoding() + { + return new MapBlockEncoding(keyType, keyBlockNativeEquals, keyNativeHashCode, getKeys().getEncoding(), getValues().getEncoding()); + } + + @Override + public Block copyPositions(List positions) + { + int[] newOffsets = new int[positions.size() + 1]; + boolean[] newMapIsNull = new boolean[positions.size()]; + + List entriesPositions = new ArrayList<>(); + int newPosition = 0; + for (int position : positions) { + if (isNull(position)) { + newMapIsNull[newPosition] = true; + newOffsets[newPosition + 1] = newOffsets[newPosition]; + } + else { + int entriesStartOffset = getOffset(position); + int entriesEndOffset = getOffset(position + 1); + int entryCount = entriesEndOffset - entriesStartOffset; + + newOffsets[newPosition + 1] = newOffsets[newPosition] + entryCount; + + for (int elementIndex = entriesStartOffset; elementIndex < entriesEndOffset; elementIndex++) { + entriesPositions.add(elementIndex); + } + } + newPosition++; + } + + int[] hashTable = getHashTables(); + int[] newHashTable = new int[newOffsets[newOffsets.length - 1] * HASH_MULTIPLIER]; + int newHashIndex = 0; + for (int position : positions) { + int entriesStartOffset = getOffset(position); + int entriesEndOffset = getOffset(position + 1); + for (int hashIndex = entriesStartOffset * HASH_MULTIPLIER; hashIndex < entriesEndOffset * HASH_MULTIPLIER; hashIndex++) { + newHashTable[newHashIndex] = hashTable[hashIndex]; + newHashIndex++; + } + } + + Block newKeys = getKeys().copyPositions(entriesPositions); + Block newValues = getValues().copyPositions(entriesPositions); + return new MapBlock(0, positions.size(), newMapIsNull, newOffsets, newKeys, newValues, newHashTable, keyType, keyBlockNativeEquals, keyNativeHashCode); + } + + @Override + public Block getRegion(int position, int length) + { + int positionCount = getPositionCount(); + if (position < 0 || length < 0 || position + length > positionCount) { + throw new IndexOutOfBoundsException("Invalid position " + position + " in block with " + positionCount + " positions"); + } + + if (position == 0 && length == positionCount) { + return this; + } + + return new MapBlock( + position + getOffsetBase(), + length, + getMapIsNull(), + getOffsets(), + getKeys(), + getValues(), + getHashTables(), + keyType, + keyBlockNativeEquals, keyNativeHashCode + ); + } + + @Override + public long getRegionSizeInBytes(int position, int length) + { + int positionCount = getPositionCount(); + if (position < 0 || length < 0 || position + length > positionCount) { + throw new IndexOutOfBoundsException("Invalid position " + position + " in block with " + positionCount + " positions"); + } + + int entriesStart = getOffsets()[getOffsetBase() + position]; + int entriesEnd = getOffsets()[getOffsetBase() + position + length]; + int entryCount = entriesEnd - entriesStart; + + return getKeys().getRegionSizeInBytes(entriesStart, entryCount) + + getValues().getRegionSizeInBytes(entriesStart, entryCount) + + (Integer.BYTES + Byte.BYTES) * length + + Integer.BYTES * HASH_MULTIPLIER * entryCount; + } + + @Override + public Block copyRegion(int position, int length) + { + int positionCount = getPositionCount(); + if (position < 0 || length < 0 || position + length > positionCount) { + throw new IndexOutOfBoundsException("Invalid position " + position + " in block with " + positionCount + " positions"); + } + + int startValueOffset = getOffset(position); + int endValueOffset = getOffset(position + length); + Block newKeys = getKeys().copyRegion(startValueOffset, endValueOffset - startValueOffset); + Block newValues = getValues().copyRegion(startValueOffset, endValueOffset - startValueOffset); + + int[] newOffsets = new int[length + 1]; + for (int i = 1; i < newOffsets.length; i++) { + newOffsets[i] = getOffset(position + i) - startValueOffset; + } + boolean[] newValueIsNull = Arrays.copyOfRange(getMapIsNull(), position + getOffsetBase(), position + getOffsetBase() + length); + int[] newHashTable = Arrays.copyOfRange(getHashTables(), startValueOffset * HASH_MULTIPLIER, endValueOffset * HASH_MULTIPLIER); + + return new MapBlock( + 0, + length, + newValueIsNull, + newOffsets, + newKeys, + newValues, + newHashTable, + keyType, + keyBlockNativeEquals, keyNativeHashCode + ); + } + + @Override + public T getObject(int position, Class clazz) + { + if (clazz != Block.class) { + throw new IllegalArgumentException("clazz must be Block.class"); + } + checkReadablePosition(position); + + int startEntryOffset = getOffset(position); + int endEntryOffset = getOffset(position + 1); + return clazz.cast(new SingleMapBlock( + startEntryOffset * 2, + (endEntryOffset - startEntryOffset) * 2, + getKeys(), + getValues(), + getHashTables(), + keyType, + keyNativeHashCode, + keyBlockNativeEquals)); + } + + @Override + public void writePositionTo(int position, BlockBuilder blockBuilder) + { + checkReadablePosition(position); + BlockBuilder entryBuilder = blockBuilder.beginBlockEntry(); + int startValueOffset = getOffset(position); + int endValueOffset = getOffset(position + 1); + for (int i = startValueOffset; i < endValueOffset; i++) { + if (getKeys().isNull(i)) { + entryBuilder.appendNull(); + } + else { + getKeys().writePositionTo(i, entryBuilder); + entryBuilder.closeEntry(); + } + if (getValues().isNull(i)) { + entryBuilder.appendNull(); + } + else { + getValues().writePositionTo(i, entryBuilder); + entryBuilder.closeEntry(); + } + } + } + + @Override + public Block getSingleValueBlock(int position) + { + checkReadablePosition(position); + + int startValueOffset = getOffset(position); + int endValueOffset = getOffset(position + 1); + int valueLength = endValueOffset - startValueOffset; + Block newKeys = getKeys().copyRegion(startValueOffset, valueLength); + Block newValues = getValues().copyRegion(startValueOffset, valueLength); + int[] newHashTable = Arrays.copyOfRange(getHashTables(), startValueOffset * HASH_MULTIPLIER, endValueOffset * HASH_MULTIPLIER); + + return new MapBlock( + 0, + 1, + new boolean[] {isNull(position)}, + new int[] {0, valueLength}, + newKeys, + newValues, + newHashTable, + keyType, + keyBlockNativeEquals, + keyNativeHashCode); + } + + @Override + public boolean isNull(int position) + { + checkReadablePosition(position); + return getMapIsNull()[position + getOffsetBase()]; + } + + private void checkReadablePosition(int position) + { + if (position < 0 || position >= getPositionCount()) { + throw new IllegalArgumentException("position is not valid"); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayElementBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java similarity index 95% rename from presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayElementBlock.java rename to presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java index 164503ce9e313..9c5e9381782d7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractArrayElementBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleArrayBlock.java @@ -17,12 +17,12 @@ import java.util.List; -public abstract class AbstractArrayElementBlock +public abstract class AbstractSingleArrayBlock implements Block { protected final int start; - protected AbstractArrayElementBlock(int start) + protected AbstractSingleArrayBlock(int start) { this.start = start; } @@ -151,7 +151,7 @@ public boolean isNull(int position) @Override public BlockEncoding getEncoding() { - // ArrayElementBlockEncoding does not exist + // SingleArrayBlockEncoding does not exist throw new UnsupportedOperationException(); } @@ -168,7 +168,7 @@ public Block getRegion(int position, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { throw new UnsupportedOperationException(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java new file mode 100644 index 0000000000000..0cdba8fb24156 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/AbstractSingleMapBlock.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import io.airlift.slice.Slice; + +import java.util.List; + +public abstract class AbstractSingleMapBlock + implements Block +{ + private final int offset; + private final Block keyBlock; + private final Block valueBlock; + + public AbstractSingleMapBlock(int offset, Block keyBlock, Block valueBlock) + { + this.offset = offset; + this.keyBlock = keyBlock; + this.valueBlock = valueBlock; + } + + private int getAbsolutePosition(int position) + { + if (position < 0 || position >= getPositionCount()) { + throw new IllegalArgumentException("position is not valid"); + } + return position + offset; + } + + @Override + public boolean isNull(int position) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + if (keyBlock.isNull(position / 2)) { + throw new IllegalStateException("Map key is null"); + } + return false; + } + else { + return valueBlock.isNull(position / 2); + } + } + + @Override + public byte getByte(int position, int offset) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getByte(position / 2, offset); + } + else { + return valueBlock.getByte(position / 2, offset); + } + } + + @Override + public short getShort(int position, int offset) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getShort(position / 2, offset); + } + else { + return valueBlock.getShort(position / 2, offset); + } + } + + @Override + public int getInt(int position, int offset) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getInt(position / 2, offset); + } + else { + return valueBlock.getInt(position / 2, offset); + } + } + + @Override + public long getLong(int position, int offset) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getLong(position / 2, offset); + } + else { + return valueBlock.getLong(position / 2, offset); + } + } + + @Override + public Slice getSlice(int position, int offset, int length) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getSlice(position / 2, offset, length); + } + else { + return valueBlock.getSlice(position / 2, offset, length); + } + } + + @Override + public int getSliceLength(int position) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getSliceLength(position / 2); + } + else { + return valueBlock.getSliceLength(position / 2); + } + } + + @Override + public boolean bytesEqual(int position, int offset, Slice otherSlice, int otherOffset, int length) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.bytesEqual(position / 2, offset, otherSlice, otherOffset, length); + } + else { + return valueBlock.bytesEqual(position / 2, offset, otherSlice, otherOffset, length); + } + } + + @Override + public int bytesCompare(int position, int offset, int length, Slice otherSlice, int otherOffset, int otherLength) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.bytesCompare(position / 2, offset, length, otherSlice, otherOffset, otherLength); + } + else { + return valueBlock.bytesCompare(position / 2, offset, length, otherSlice, otherOffset, otherLength); + } + } + + @Override + public void writeBytesTo(int position, int offset, int length, BlockBuilder blockBuilder) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + keyBlock.writeBytesTo(position / 2, offset, length, blockBuilder); + } + else { + valueBlock.writeBytesTo(position / 2, offset, length, blockBuilder); + } + } + + @Override + public boolean equals(int position, int offset, Block otherBlock, int otherPosition, int otherOffset, int length) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.equals(position / 2, offset, otherBlock, otherPosition, otherOffset, length); + } + else { + return valueBlock.equals(position / 2, offset, otherBlock, otherPosition, otherOffset, length); + } + } + + @Override + public long hash(int position, int offset, int length) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.hash(position / 2, offset, length); + } + else { + return valueBlock.hash(position / 2, offset, length); + } + } + + @Override + public T getObject(int position, Class clazz) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getObject(position / 2, clazz); + } + else { + return valueBlock.getObject(position / 2, clazz); + } + } + + @Override + public void writePositionTo(int position, BlockBuilder blockBuilder) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + keyBlock.writePositionTo(position / 2, blockBuilder); + } + else { + valueBlock.writePositionTo(position / 2, blockBuilder); + } + } + + @Override + public Block getSingleValueBlock(int position) + { + position = getAbsolutePosition(position); + if (position % 2 == 0) { + return keyBlock.getSingleValueBlock(position / 2); + } + else { + return valueBlock.getSingleValueBlock(position / 2); + } + } + + @Override + public long getRegionSizeInBytes(int position, int length) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block copyPositions(List positions) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block getRegion(int positionOffset, int length) + { + throw new UnsupportedOperationException(); + } + + @Override + public Block copyRegion(int position, int length) + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java index 5f2872a963725..6ad6271679912 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlock.java @@ -15,7 +15,8 @@ import org.openjdk.jol.info.ClassLayout; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; +import java.util.function.BiConsumer; + import static io.airlift.slice.SizeOf.sizeOf; import static java.util.Objects.requireNonNull; @@ -30,8 +31,8 @@ public class ArrayBlock private final Block values; private final int[] offsets; - private int sizeInBytes; - private final int retainedSizeInBytes; + private long sizeInBytes; + private final long retainedSizeInBytes; public ArrayBlock(int positionCount, boolean[] valueIsNull, int[] offsets, Block values) { @@ -65,7 +66,7 @@ public ArrayBlock(int positionCount, boolean[] valueIsNull, int[] offsets, Block this.values = requireNonNull(values); sizeInBytes = -1; - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + values.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(valueIsNull)); + retainedSizeInBytes = INSTANCE_SIZE + values.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(valueIsNull); } @Override @@ -75,7 +76,7 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { // this is racy but is safe because sizeInBytes is an int and the calculation is stable if (sizeInBytes < 0) { @@ -88,15 +89,24 @@ private void calculateSize() { int valueStart = offsets[arrayOffset]; int valueEnd = offsets[arrayOffset + positionCount]; - sizeInBytes = intSaturatedCast(values.getRegionSizeInBytes(valueStart, valueEnd - valueStart) + ((Integer.BYTES + Byte.BYTES) * (long) this.positionCount)); + sizeInBytes = values.getRegionSizeInBytes(valueStart, valueEnd - valueStart) + ((Integer.BYTES + Byte.BYTES) * (long) this.positionCount); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, values.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override protected Block getValues() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java index ddd6cda11319e..49f26b6c3345a 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayBlockBuilder.java @@ -16,10 +16,12 @@ import com.facebook.presto.spi.type.Type; import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; import static java.util.Objects.requireNonNull; @@ -28,10 +30,11 @@ public class ArrayBlockBuilder extends AbstractArrayBlock implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(ArrayBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ArrayBlockBuilder.class).instanceSize(); private int positionCount; + @Nullable private BlockBuilderStatus blockBuilderStatus; private boolean initialized; private int initialEntryCount; @@ -40,9 +43,9 @@ public class ArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private final BlockBuilder values; - private int currentEntrySize; + private boolean currentEntryOpened; - private int retainedSizeInBytes; + private long retainedSizeInBytes; /** * Caller of this constructor is responsible for making sure `valuesBlock` is constructed with the same `blockBuilderStatus` as the one in the argument @@ -74,9 +77,9 @@ public ArrayBlockBuilder(Type elementType, BlockBuilderStatus blockBuilderStatus /** * Caller of this private constructor is responsible for making sure `values` is constructed with the same `blockBuilderStatus` as the one in the argument */ - private ArrayBlockBuilder(BlockBuilderStatus blockBuilderStatus, BlockBuilder values, int expectedEntries) + private ArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, BlockBuilder values, int expectedEntries) { - this.blockBuilderStatus = requireNonNull(blockBuilderStatus, "blockBuilderStatus is null"); + this.blockBuilderStatus = blockBuilderStatus; this.values = requireNonNull(values, "values is null"); this.initialEntryCount = max(expectedEntries, 1); @@ -90,15 +93,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return values.getSizeInBytes() + ((Integer.BYTES + Byte.BYTES) * positionCount); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { - return retainedSizeInBytes; + return retainedSizeInBytes + values.getRetainedSizeInBytes(); + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, values.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override @@ -128,8 +140,8 @@ protected boolean[] getValueIsNull() @Override public BlockBuilder writeObject(Object value) { - if (currentEntrySize != 0) { - throw new IllegalStateException("Expected entry size to be exactly " + 0 + " but was " + currentEntrySize); + if (currentEntryOpened) { + throw new IllegalStateException("Expected current entry to be closed but was opened"); } Block block = (Block) value; @@ -143,36 +155,36 @@ public BlockBuilder writeObject(Object value) } } - currentEntrySize++; + currentEntryOpened = true; return this; } @Override - public ArrayElementBlockWriter beginBlockEntry() + public SingleArrayBlockWriter beginBlockEntry() { - if (currentEntrySize != 0) { - throw new IllegalStateException("Expected current entry size to be exactly 0 but was " + currentEntrySize); + if (currentEntryOpened) { + throw new IllegalStateException("Expected current entry to be closed but was closed"); } - currentEntrySize++; - return new ArrayElementBlockWriter(values, values.getPositionCount()); + currentEntryOpened = true; + return new SingleArrayBlockWriter(values, values.getPositionCount()); } @Override public BlockBuilder closeEntry() { - if (currentEntrySize != 1) { - throw new IllegalStateException("Expected entry size to be exactly 1 but was " + currentEntrySize); + if (!currentEntryOpened) { + throw new IllegalStateException("Expected entry to be opened but was closed"); } entryAdded(false); - currentEntrySize = 0; + currentEntryOpened = false; return this; } @Override public BlockBuilder appendNull() { - if (currentEntrySize > 0) { + if (currentEntryOpened) { throw new IllegalStateException("Current entry must be closed before a null can be written"); } @@ -189,7 +201,9 @@ private void entryAdded(boolean isNull) valueIsNull[positionCount] = isNull; positionCount++; - blockBuilderStatus.addBytes(Integer.BYTES + Byte.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Integer.BYTES + Byte.BYTES); + } } private void growCapacity() @@ -210,13 +224,16 @@ private void growCapacity() private void updateDataSize() { - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(offsets)); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(offsets); + if (blockBuilderStatus != null) { + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; + } } @Override public ArrayBlock build() { - if (currentEntrySize > 0) { + if (currentEntryOpened) { throw new IllegalStateException("Current entry must be closed before the block can be built"); } return new ArrayBlock(positionCount, valueIsNull, offsets, values.build()); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java index 4f08875259ce7..ad955ec42fa35 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/Block.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import java.util.List; +import java.util.function.BiConsumer; public interface Block { @@ -162,18 +163,28 @@ default int compareTo(int leftPosition, int leftOffset, int leftLength, Block ri /** * Returns the logical size of this block in memory. */ - int getSizeInBytes(); + long getSizeInBytes(); /** * Returns the logical size of {@code block.getRegion(position, length)} in memory. */ - int getRegionSizeInBytes(int position, int length); + long getRegionSizeInBytes(int position, int length); /** * Returns the retained size of this block in memory. * This method is called from the inner most execution loop and must be fast. */ - int getRetainedSizeInBytes(); + long getRetainedSizeInBytes(); + + /** + * {@code consumer} visits each of the internal data container and accepts the size for it. + * This method can be helpful in cases such as memory counting for internal data structure. + * Also, the method should be non-recursive, only visit the elements at the top level, + * and specifically should not call retainedBytesForEachPart on nested blocks + * {@code consumer} should be called at least once with the current block and + * must include the instance size of the current block + */ + void retainedBytesForEachPart(BiConsumer consumer); /** * Get the encoding for this block. diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java index 21487bfa55a9d..d28fae9432b64 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/BlockUtil.java @@ -76,12 +76,4 @@ else if (newSize > MAX_ARRAY_SIZE) { } return (int) newSize; } - - static int intSaturatedCast(long value) - { - if (value > Integer.MAX_VALUE) { - return Integer.MAX_VALUE; - } - return (int) value; - } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java index e85ee5d1b22c5..af9940d3fcdae 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class ByteArrayBlock @@ -32,8 +32,8 @@ public class ByteArrayBlock private final boolean[] valueIsNull; private final byte[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public ByteArrayBlock(int positionCount, boolean[] valueIsNull, byte[] values) { @@ -61,28 +61,36 @@ public ByteArrayBlock(int positionCount, boolean[] valueIsNull, byte[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast((INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values))); + sizeInBytes = (Byte.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = (INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) length); + return (Byte.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java index f5312f7713ed1..260880c4c7a68 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ByteArrayBlockBuilder.java @@ -15,20 +15,22 @@ import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; -import static java.util.Objects.requireNonNull; public class ByteArrayBlockBuilder implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteArrayBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteArrayBlockBuilder.class).instanceSize(); + @Nullable private BlockBuilderStatus blockBuilderStatus; private boolean initialized; private int initialEntryCount; @@ -39,11 +41,11 @@ public class ByteArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private byte[] values = new byte[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; - public ByteArrayBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public ByteArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { - this.blockBuilderStatus = requireNonNull(blockBuilderStatus, "blockBuilderStatus is null"); + this.blockBuilderStatus = blockBuilderStatus; this.initialEntryCount = max(expectedEntries, 1); updateDataSize(); @@ -59,7 +61,9 @@ public BlockBuilder writeByte(int value) values[positionCount] = (byte) value; positionCount++; - blockBuilderStatus.addBytes((Byte.BYTES + Byte.BYTES)); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes((Byte.BYTES + Byte.BYTES)); + } return this; } @@ -79,7 +83,9 @@ public BlockBuilder appendNull() valueIsNull[positionCount] = true; positionCount++; - blockBuilderStatus.addBytes((Byte.BYTES + Byte.BYTES)); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes((Byte.BYTES + Byte.BYTES)); + } return this; } @@ -113,28 +119,39 @@ private void growCapacity() private void updateDataSize() { - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + if (blockBuilderStatus != null) { + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; + } } // Copied from ByteArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) positionCount); + return (Byte.BYTES + Byte.BYTES) * positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Byte.BYTES + Byte.BYTES) * (long) length); + return (Byte.BYTES + Byte.BYTES) * length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java index 8b6b50b9c9b31..d9e5f97b12d2c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/DictionaryBlock.java @@ -22,12 +22,12 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; import static com.facebook.presto.spi.block.DictionaryId.randomDictionaryId; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.min; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class DictionaryBlock @@ -39,11 +39,16 @@ public class DictionaryBlock private final Block dictionary; private final int idsOffset; private final int[] ids; - private final int retainedSizeInBytes; - private volatile int sizeInBytes = -1; + private final long retainedSizeInBytes; + private volatile long sizeInBytes = -1; private volatile int uniqueIds = -1; private final DictionaryId dictionarySourceId; + public DictionaryBlock(Block dictionary, int[] ids) + { + this(requireNonNull(ids, "ids is null").length, dictionary, ids); + } + public DictionaryBlock(int positionCount, Block dictionary, int[] ids) { this(0, positionCount, dictionary, ids, false, randomDictionaryId()); @@ -82,7 +87,7 @@ private DictionaryBlock(int idsOffset, int positionCount, Block dictionary, int[ this.dictionary = dictionary; this.ids = ids; this.dictionarySourceId = requireNonNull(dictionarySourceId, "dictionarySourceId is null"); - this.retainedSizeInBytes = toIntExact(INSTANCE_SIZE + dictionary.getRetainedSizeInBytes() + sizeOf(ids)); + this.retainedSizeInBytes = INSTANCE_SIZE + dictionary.getRetainedSizeInBytes() + sizeOf(ids); if (dictionaryIsCompacted) { this.sizeInBytes = this.retainedSizeInBytes; @@ -187,7 +192,7 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { // this is racy but is safe because sizeInBytes is an int and the calculation is stable if (sizeInBytes < 0) { @@ -216,7 +221,7 @@ private void calculateCompactSize() } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { if (positionOffset == 0 && length == getPositionCount()) { // Calculation of getRegionSizeInBytes is expensive in this class. @@ -239,11 +244,19 @@ public int getRegionSizeInBytes(int positionOffset, int length) } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(dictionary, dictionary.getRetainedSizeInBytes()); + consumer.accept(ids, sizeOf(ids)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockEncoding getEncoding() { @@ -267,7 +280,7 @@ public Block copyPositions(List positions) } newIds[i] = oldIndexToNewIndex.get(oldIndex); } - return new DictionaryBlock(positions.size(), dictionary.copyPositions(positionsToCopy), newIds); + return new DictionaryBlock(dictionary.copyPositions(positionsToCopy), newIds); } @Override @@ -286,7 +299,7 @@ public Block copyRegion(int position, int length) throw new IndexOutOfBoundsException("Invalid position " + position + " in block with " + positionCount + " positions"); } int[] newIds = Arrays.copyOfRange(ids, idsOffset + position, idsOffset + position + length); - DictionaryBlock dictionaryBlock = new DictionaryBlock(length, dictionary, newIds); + DictionaryBlock dictionaryBlock = new DictionaryBlock(dictionary, newIds); return dictionaryBlock.compact(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java index 27725058b3203..7462390e4a4ce 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlock.java @@ -19,9 +19,9 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static java.util.Objects.requireNonNull; public class FixedWidthBlock @@ -72,15 +72,23 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast(getRawSlice().length() + valueIsNull.length()); + return getRawSlice().length() + valueIsNull.length(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { - return intSaturatedCast(INSTANCE_SIZE + getRawSlice().getRetainedSize() + valueIsNull.getRetainedSize()); + return INSTANCE_SIZE + getRawSlice().getRetainedSize() + valueIsNull.getRetainedSize(); + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(slice, (long) slice.getRetainedSize()); + consumer.accept(valueIsNull, (long) valueIsNull.getRetainedSize()); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java index 9992c441248ff..e98bbc30398ac 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/FixedWidthBlockBuilder.java @@ -19,12 +19,14 @@ import io.airlift.slice.Slices; import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.MAX_ARRAY_SIZE; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; @@ -34,8 +36,9 @@ public class FixedWidthBlockBuilder extends AbstractFixedWidthBlock implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(FixedWidthBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(FixedWidthBlockBuilder.class).instanceSize(); + @Nullable private BlockBuilderStatus blockBuilderStatus; private boolean initialized; @@ -47,7 +50,7 @@ public class FixedWidthBlockBuilder private int currentEntrySize; - public FixedWidthBlockBuilder(int fixedSize, BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public FixedWidthBlockBuilder(int fixedSize, @Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { super(fixedSize); @@ -82,15 +85,27 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast(sliceOutput.size() + valueIsNull.size()); + return sliceOutput.size() + valueIsNull.size(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { - return intSaturatedCast(INSTANCE_SIZE + sliceOutput.getRetainedSize() + valueIsNull.getRetainedSize()); + long size = INSTANCE_SIZE + sliceOutput.getRetainedSize() + valueIsNull.getRetainedSize(); + if (blockBuilderStatus != null) { + size += BlockBuilderStatus.INSTANCE_SIZE; + } + return size; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(sliceOutput, (long) sliceOutput.getRetainedSize()); + consumer.accept(valueIsNull, (long) valueIsNull.getRetainedSize()); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override @@ -191,7 +206,9 @@ private void entryAdded(boolean isNull) valueIsNull.appendByte(isNull ? 1 : 0); positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + fixedSize); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + fixedSize); + } } private void checkCapacity() diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java index 2a2226fa32802..3456da04d6800 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class IntArrayBlock @@ -32,8 +32,8 @@ public class IntArrayBlock private final boolean[] valueIsNull; private final int[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public IntArrayBlock(int positionCount, boolean[] valueIsNull, int[] values) { @@ -61,28 +61,36 @@ public IntArrayBlock(int positionCount, boolean[] valueIsNull, int[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + sizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) length); + return (Integer.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java index e8f4ea9f9bd83..fcb05e279f400 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/IntArrayBlockBuilder.java @@ -15,21 +15,23 @@ import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; -import static java.util.Objects.requireNonNull; public class IntArrayBlockBuilder implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(IntArrayBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(IntArrayBlockBuilder.class).instanceSize(); + @Nullable private BlockBuilderStatus blockBuilderStatus; private boolean initialized; private int initialEntryCount; @@ -40,11 +42,11 @@ public class IntArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private int[] values = new int[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; - public IntArrayBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public IntArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { - this.blockBuilderStatus = requireNonNull(blockBuilderStatus, "blockBuilderStatus is null"); + this.blockBuilderStatus = blockBuilderStatus; this.initialEntryCount = max(expectedEntries, 1); updateDataSize(); @@ -60,7 +62,9 @@ public BlockBuilder writeInt(int value) values[positionCount] = value; positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + Integer.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + Integer.BYTES); + } return this; } @@ -80,7 +84,9 @@ public BlockBuilder appendNull() valueIsNull[positionCount] = true; positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + Integer.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + Integer.BYTES); + } return this; } @@ -114,28 +120,39 @@ private void growCapacity() private void updateDataSize() { - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + if (blockBuilderStatus != null) { + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; + } } // Copied from IntArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) positionCount); + return (Integer.BYTES + Byte.BYTES) * positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Integer.BYTES + Byte.BYTES) * (long) length); + return (Integer.BYTES + Byte.BYTES) * length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java index 1978c4633f38e..c021733b4a1fb 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlock.java @@ -15,7 +15,8 @@ import org.openjdk.jol.info.ClassLayout; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; public class InterleavedBlock extends AbstractInterleavedBlock @@ -26,17 +27,17 @@ public class InterleavedBlock private final InterleavedBlockEncoding blockEncoding; private final int start; private final int positionCount; - private final int retainedSizeInBytes; + private final long retainedSizeInBytes; - private final AtomicInteger sizeInBytes; + private final AtomicLong sizeInBytes; public InterleavedBlock(Block[] blocks) { super(blocks.length); this.blocks = blocks; - int sizeInBytes = 0; - int retainedSizeInBytes = INSTANCE_SIZE; + long sizeInBytes = 0; + long retainedSizeInBytes = INSTANCE_SIZE; int positionCount = 0; int firstSubBlockPositionCount = blocks[0].getPositionCount(); for (int i = 0; i < getBlockCount(); i++) { @@ -52,11 +53,11 @@ public InterleavedBlock(Block[] blocks) this.blockEncoding = computeBlockEncoding(); this.start = 0; this.positionCount = positionCount; - this.sizeInBytes = new AtomicInteger(sizeInBytes); + this.sizeInBytes = new AtomicLong(sizeInBytes); this.retainedSizeInBytes = retainedSizeInBytes; } - private InterleavedBlock(Block[] blocks, int start, int positionCount, int retainedSizeInBytes, InterleavedBlockEncoding blockEncoding) + private InterleavedBlock(Block[] blocks, int start, int positionCount, long retainedSizeInBytes, InterleavedBlockEncoding blockEncoding) { super(blocks.length); this.blocks = blocks; @@ -64,7 +65,7 @@ private InterleavedBlock(Block[] blocks, int start, int positionCount, int retai this.positionCount = positionCount; this.retainedSizeInBytes = retainedSizeInBytes; this.blockEncoding = blockEncoding; - this.sizeInBytes = new AtomicInteger(-1); + this.sizeInBytes = new AtomicLong(-1); } @Override @@ -103,9 +104,9 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - int sizeInBytes = this.sizeInBytes.get(); + long sizeInBytes = this.sizeInBytes.get(); if (sizeInBytes < 0) { sizeInBytes = 0; for (int i = 0; i < getBlockCount(); i++) { @@ -117,11 +118,18 @@ public int getSizeInBytes() } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(blocks, retainedSizeInBytes - INSTANCE_SIZE); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public String toString() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java index 6da6ffc70f0ec..32ff7c95b9d8c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/InterleavedBlockBuilder.java @@ -18,6 +18,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static java.util.Objects.requireNonNull; @@ -33,10 +34,10 @@ public class InterleavedBlockBuilder private int positionCount; private int currentBlockIndex; - private int sizeInBytes; - private int startSize; - private int retainedSizeInBytes; - private int startRetainedSize; + private long sizeInBytes; + private long startSize; + private long retainedSizeInBytes; + private long startRetainedSize; public InterleavedBlockBuilder(List types, BlockBuilderStatus blockBuilderStatus, int expectedEntries) { @@ -109,17 +110,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(blockBuilders, retainedSizeInBytes - INSTANCE_SIZE); + consumer.accept(this, (long) INSTANCE_SIZE); + } + private void recordStartSizesIfNecessary(BlockBuilder blockBuilder) { if (startSize < 0) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java index 10d3142e395ce..7ac0b66ab43bc 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/LazyBlock.java @@ -17,6 +17,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static java.util.Objects.requireNonNull; @@ -164,26 +165,36 @@ public Block getSingleValueBlock(int position) } @Override - public int getSizeInBytes() + public long getSizeInBytes() { assureLoaded(); return block.getSizeInBytes(); } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { assureLoaded(); return block.getRegionSizeInBytes(position, length); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { assureLoaded(); return INSTANCE_SIZE + block.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + // do not support LazyBlock (for now) for the following two reasons: + // (1) the method is mainly used for inspecting the identity and size of each element to prevent over counting + // (2) the method should be non-recursive and only inspects blocks at the top level; + // given LazyBlock is a wrapper for other blocks, it is not meaningful to only inspect the top-level elements + throw new UnsupportedOperationException(getClass().getName()); + } + @Override public BlockEncoding getEncoding() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java index ef29a44d0899f..17e0ff44ffd2a 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.toIntExact; @@ -33,8 +33,8 @@ public class LongArrayBlock private final boolean[] valueIsNull; private final long[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public LongArrayBlock(int positionCount, boolean[] valueIsNull, long[] values) { @@ -62,28 +62,36 @@ public LongArrayBlock(int positionCount, boolean[] valueIsNull, long[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + sizeInBytes = (Long.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) length); + return (Long.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java index 89894dd14e548..99aa6b7e0d270 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/LongArrayBlockBuilder.java @@ -15,22 +15,24 @@ import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; public class LongArrayBlockBuilder implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(LongArrayBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(LongArrayBlockBuilder.class).instanceSize(); + @Nullable private BlockBuilderStatus blockBuilderStatus; private boolean initialized; private int initialEntryCount; @@ -41,11 +43,11 @@ public class LongArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private long[] values = new long[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; - public LongArrayBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public LongArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { - this.blockBuilderStatus = requireNonNull(blockBuilderStatus, "blockBuilderStatus is null"); + this.blockBuilderStatus = blockBuilderStatus; this.initialEntryCount = max(expectedEntries, 1); updateDataSize(); @@ -61,7 +63,9 @@ public BlockBuilder writeLong(long value) values[positionCount] = value; positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + Long.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + Long.BYTES); + } return this; } @@ -81,7 +85,9 @@ public BlockBuilder appendNull() valueIsNull[positionCount] = true; positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + Long.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + Long.BYTES); + } return this; } @@ -115,28 +121,39 @@ private void growCapacity() private void updateDataSize() { - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + if (blockBuilderStatus != null) { + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; + } } // Copied from LongArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) positionCount); + return (Long.BYTES + Byte.BYTES) * positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Long.BYTES + Byte.BYTES) * (long) length); + return (Long.BYTES + Byte.BYTES) * length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java new file mode 100644 index 0000000000000..ad92074709881 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlock.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.Type; +import org.openjdk.jol.info.ClassLayout; + +import java.lang.invoke.MethodHandle; +import java.util.Arrays; +import java.util.function.BiConsumer; + +import static com.facebook.presto.spi.block.MapBlockBuilder.buildHashTable; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class MapBlock + extends AbstractMapBlock +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapBlock.class).instanceSize(); + + private final int startOffset; + private final int positionCount; + + private final boolean[] mapIsNull; + private final int[] offsets; + private final Block keyBlock; + private final Block valueBlock; + private final int[] hashTables; // hash to location in map; + + private long sizeInBytes; + private final long retainedSizeInBytes; + + /** + * @param keyBlockNativeEquals (T, Block, int)boolean + * @param keyNativeHashCode (T)long + */ + MapBlock( + int startOffset, + int positionCount, + boolean[] mapIsNull, + int[] offsets, + Block keyBlock, + Block valueBlock, + int[] hashTables, + Type keyType, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode) + { + super(keyType, keyNativeHashCode, keyBlockNativeEquals); + + this.startOffset = startOffset; + this.positionCount = positionCount; + this.mapIsNull = mapIsNull; + this.offsets = requireNonNull(offsets, "offsets is null"); + this.keyBlock = requireNonNull(keyBlock, "keyBlock is null"); + this.valueBlock = requireNonNull(valueBlock, "valueBlock is null"); + if (keyBlock.getPositionCount() != valueBlock.getPositionCount()) { + throw new IllegalArgumentException(format("keyBlock and valueBlock has different size: %s %s", keyBlock.getPositionCount(), valueBlock.getPositionCount())); + } + if (hashTables.length < keyBlock.getPositionCount() * HASH_MULTIPLIER) { + throw new IllegalArgumentException(format("keyBlock/valueBlock size does not match hash table size: %s %s", keyBlock.getPositionCount(), hashTables.length)); + } + this.hashTables = hashTables; + + this.sizeInBytes = -1; + this.retainedSizeInBytes = INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables); + } + + @Override + protected Block getKeys() + { + return keyBlock; + } + + @Override + protected Block getValues() + { + return valueBlock; + } + + @Override + protected int[] getHashTables() + { + return hashTables; + } + + @Override + protected int[] getOffsets() + { + return offsets; + } + + @Override + protected int getOffsetBase() + { + return startOffset; + } + + @Override + protected boolean[] getMapIsNull() + { + return mapIsNull; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + // this is racy but is safe because sizeInBytes is an int and the calculation is stable + if (sizeInBytes < 0) { + calculateSize(); + } + return sizeInBytes; + } + + private void calculateSize() + { + int entriesStart = offsets[startOffset]; + int entriesEnd = offsets[startOffset + positionCount]; + int entryCount = entriesEnd - entriesStart; + sizeInBytes = keyBlock.getRegionSizeInBytes(entriesStart, entryCount) + + valueBlock.getRegionSizeInBytes(entriesStart, entryCount) + + (Integer.BYTES + Byte.BYTES) * this.positionCount + + Integer.BYTES * HASH_MULTIPLIER * entryCount; + } + + @Override + public long getRetainedSizeInBytes() + { + return retainedSizeInBytes; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlock, keyBlock.getRetainedSizeInBytes()); + consumer.accept(valueBlock, valueBlock.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(mapIsNull, sizeOf(mapIsNull)); + consumer.accept(hashTables, sizeOf(hashTables)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + + @Override + public String toString() + { + StringBuilder sb = new StringBuilder("MapBlock{"); + sb.append("positionCount=").append(getPositionCount()); + sb.append('}'); + return sb.toString(); + } + + public static MapBlock fromKeyValueBlock( + boolean useNewMapBlock, + boolean[] mapIsNull, + int[] offsets, + Block keyBlock, + Block valueBlock, + MapType mapType, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode, + MethodHandle keyBlockHashCode) + { + if (keyBlock.getPositionCount() != valueBlock.getPositionCount()) { + throw new IllegalArgumentException(format("keyBlock position count does not match valueBlock position count. %s %s", keyBlock.getPositionCount(), valueBlock.getPositionCount())); + } + int elementCount = keyBlock.getPositionCount(); + if (mapIsNull.length != offsets.length - 1) { + throw new IllegalArgumentException(format("mapIsNull.length-1 does not match offsets.length. %s %s", mapIsNull.length - 1, offsets.length)); + } + int mapCount = mapIsNull.length; + if (offsets[mapCount] != elementCount) { + throw new IllegalArgumentException(format("Last element of offsets does not match keyBlock position count. %s %s", offsets[mapCount], keyBlock.getPositionCount())); + } + int[] hashTables = new int[elementCount * HASH_MULTIPLIER]; + Arrays.fill(hashTables, -1); + for (int i = 0; i < mapCount; i++) { + int keyOffset = offsets[i]; + int keyCount = offsets[i + 1] - keyOffset; + if (keyCount < 0) { + throw new IllegalArgumentException(format("Offset is not monotonically ascending. offsets[%s]=%s, offsets[%s]=%s", i, offsets[i], i + 1, offsets[i + 1])); + } + buildHashTable(useNewMapBlock, keyBlock, keyOffset, keyCount, keyBlockHashCode, hashTables, keyOffset * HASH_MULTIPLIER, keyCount * HASH_MULTIPLIER); + } + + return new MapBlock( + 0, + mapCount, + mapIsNull, + offsets, + keyBlock, + valueBlock, + hashTables, + mapType.getKeyType(), + keyBlockNativeEquals, + keyNativeHashCode); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java new file mode 100644 index 0000000000000..8f1bf513384d2 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockBuilder.java @@ -0,0 +1,369 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.type.Type; +import org.openjdk.jol.info.ClassLayout; + +import javax.annotation.Nullable; + +import java.lang.invoke.MethodHandle; +import java.util.Arrays; +import java.util.function.BiConsumer; + +import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class MapBlockBuilder + extends AbstractMapBlock + implements BlockBuilder +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapBlockBuilder.class).instanceSize(); + + private final boolean useNewMapBlock; + private final MethodHandle keyBlockHashCode; + + @Nullable + private BlockBuilderStatus blockBuilderStatus; + + private int positionCount; + private int[] offsets; + private boolean[] mapIsNull; + private final BlockBuilder keyBlockBuilder; + private final BlockBuilder valueBlockBuilder; + private int[] hashTables; + + private boolean currentEntryOpened; + + public MapBlockBuilder( + boolean useNewMapBlock, + Type keyType, + Type valueType, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode, + MethodHandle keyBlockHashCode, + BlockBuilderStatus blockBuilderStatus, + int expectedEntries) + { + this( + useNewMapBlock, + keyType, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode, + blockBuilderStatus, + keyType.createBlockBuilder(blockBuilderStatus, expectedEntries), + valueType.createBlockBuilder(blockBuilderStatus, expectedEntries), + new int[expectedEntries + 1], + new boolean[expectedEntries], + newNegativeOneFilledArray(expectedEntries * HASH_MULTIPLIER)); + } + + private MapBlockBuilder( + boolean useNewMapBlock, + Type keyType, + MethodHandle keyBlockNativeEquals, + MethodHandle keyNativeHashCode, + MethodHandle keyBlockHashCode, + @Nullable BlockBuilderStatus blockBuilderStatus, + BlockBuilder keyBlockBuilder, + BlockBuilder valueBlockBuilder, + int[] offsets, + boolean[] mapIsNull, + int[] hashTables) + { + super(keyType, keyNativeHashCode, keyBlockNativeEquals); + + this.useNewMapBlock = useNewMapBlock; + if (useNewMapBlock) { + requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); + } + else { + if (keyBlockHashCode != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyBlockHashCode should be null."); + } + } + this.keyBlockHashCode = keyBlockHashCode; + this.blockBuilderStatus = blockBuilderStatus; + + this.positionCount = 0; + this.offsets = requireNonNull(offsets, "offsets is null"); + this.mapIsNull = requireNonNull(mapIsNull, "mapIsNull is null"); + this.keyBlockBuilder = requireNonNull(keyBlockBuilder, "keyBlockBuilder is null"); + this.valueBlockBuilder = requireNonNull(valueBlockBuilder, "valueBlockBuilder is null"); + this.hashTables = requireNonNull(hashTables, "hashTables is null"); + } + + @Override + protected Block getKeys() + { + return keyBlockBuilder; + } + + @Override + protected Block getValues() + { + return valueBlockBuilder; + } + + @Override + protected int[] getHashTables() + { + return hashTables; + } + + @Override + protected int[] getOffsets() + { + return offsets; + } + + @Override + protected int getOffsetBase() + { + return 0; + } + + @Override + protected boolean[] getMapIsNull() + { + return mapIsNull; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + return keyBlockBuilder.getSizeInBytes() + valueBlockBuilder.getSizeInBytes() + + (Integer.BYTES + Byte.BYTES) * positionCount + + Integer.BYTES * HASH_MULTIPLIER * keyBlockBuilder.getPositionCount(); + } + + @Override + public long getRetainedSizeInBytes() + { + long size = INSTANCE_SIZE + keyBlockBuilder.getRetainedSizeInBytes() + valueBlockBuilder.getRetainedSizeInBytes() + sizeOf(offsets) + sizeOf(mapIsNull) + sizeOf(hashTables); + if (blockBuilderStatus != null) { + size += BlockBuilderStatus.INSTANCE_SIZE; + } + return size; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlockBuilder, keyBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(valueBlockBuilder, valueBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(mapIsNull, sizeOf(mapIsNull)); + consumer.accept(hashTables, sizeOf(hashTables)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + + @Override + public SingleMapBlockWriter beginBlockEntry() + { + if (currentEntryOpened) { + throw new IllegalStateException("Expected current entry to be closed but was opened"); + } + currentEntryOpened = true; + return new SingleMapBlockWriter(keyBlockBuilder.getPositionCount() * 2, keyBlockBuilder, valueBlockBuilder); + } + + @Override + public BlockBuilder closeEntry() + { + if (!currentEntryOpened) { + throw new IllegalStateException("Expected entry to be opened but was closed"); + } + + entryAdded(false); + currentEntryOpened = false; + + int previousAggregatedEntryCount = offsets[positionCount - 1]; + int aggregatedEntryCount = offsets[positionCount]; + int entryCount = aggregatedEntryCount - previousAggregatedEntryCount; + if (hashTables.length < aggregatedEntryCount * HASH_MULTIPLIER) { + int newSize = BlockUtil.calculateNewArraySize(aggregatedEntryCount * HASH_MULTIPLIER); + int oldSize = hashTables.length; + hashTables = Arrays.copyOf(hashTables, newSize); + Arrays.fill(hashTables, oldSize, hashTables.length, -1); + } + buildHashTable(useNewMapBlock, keyBlockBuilder, previousAggregatedEntryCount, entryCount, keyBlockHashCode, hashTables, previousAggregatedEntryCount * HASH_MULTIPLIER, entryCount * HASH_MULTIPLIER); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(entryCount * HASH_MULTIPLIER * Integer.BYTES); + } + + return this; + } + + @Override + public BlockBuilder appendNull() + { + if (currentEntryOpened) { + throw new IllegalStateException("Current entry must be closed before a null can be written"); + } + + entryAdded(true); + return this; + } + + private void entryAdded(boolean isNull) + { + if (keyBlockBuilder.getPositionCount() != valueBlockBuilder.getPositionCount()) { + throw new IllegalStateException(format("keyBlock and valueBlock has different size: %s %s", keyBlockBuilder.getPositionCount(), valueBlockBuilder.getPositionCount())); + } + if (mapIsNull.length <= positionCount) { + int newSize = BlockUtil.calculateNewArraySize(mapIsNull.length); + mapIsNull = Arrays.copyOf(mapIsNull, newSize); + offsets = Arrays.copyOf(offsets, newSize + 1); + } + offsets[positionCount + 1] = keyBlockBuilder.getPositionCount(); + mapIsNull[positionCount] = isNull; + positionCount++; + + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Integer.BYTES + Byte.BYTES); + } + } + + @Override + public Block build() + { + if (currentEntryOpened) { + throw new IllegalStateException("Current entry must be closed before the block can be built"); + } + return new MapBlock( + 0, + positionCount, + mapIsNull, + offsets, + keyBlockBuilder.build(), + valueBlockBuilder.build(), + Arrays.copyOf(hashTables, offsets[positionCount] * HASH_MULTIPLIER), + keyType, + keyBlockNativeEquals, keyNativeHashCode + ); + } + + @Override + public String toString() + { + return "MapBlockBuilder{" + + "positionCount=" + getPositionCount() + + '}'; + } + + @Override + public BlockBuilder writeObject(Object value) + { + if (currentEntryOpened) { + throw new IllegalStateException("Expected current entry to be closed but was opened"); + } + currentEntryOpened = true; + + Block block = (Block) value; + int blockPositionCount = block.getPositionCount(); + if (blockPositionCount % 2 != 0) { + throw new IllegalArgumentException(format("block position count is not even: %s", blockPositionCount)); + } + for (int i = 0; i < blockPositionCount; i += 2) { + if (block.isNull(i)) { + throw new IllegalArgumentException("Map keys must not be null"); + } + else { + block.writePositionTo(i, keyBlockBuilder); + keyBlockBuilder.closeEntry(); + } + if (block.isNull(i + 1)) { + valueBlockBuilder.appendNull(); + } + else { + block.writePositionTo(i + 1, valueBlockBuilder); + valueBlockBuilder.closeEntry(); + } + } + return this; + } + + @Override + public BlockBuilder newBlockBuilderLike(BlockBuilderStatus blockBuilderStatus) + { + int newSize = calculateBlockResetSize(getPositionCount()); + return new MapBlockBuilder( + useNewMapBlock, + keyType, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode, + blockBuilderStatus, + keyBlockBuilder.newBlockBuilderLike(blockBuilderStatus), + valueBlockBuilder.newBlockBuilderLike(blockBuilderStatus), + new int[newSize + 1], + new boolean[newSize], + newNegativeOneFilledArray(newSize * HASH_MULTIPLIER)); + } + + private static int[] newNegativeOneFilledArray(int size) + { + int[] hashTable = new int[size]; + Arrays.fill(hashTable, -1); + return hashTable; + } + + static void buildHashTable(boolean useNewMapBlock, Block keyBlock, int keyOffset, int keyCount, MethodHandle keyBlockHashCode, int[] outputHashTable, int hashTableOffset, int hashTableSize) + { + if (!useNewMapBlock) { + return; + } + + // This method assumes that keyBlock has no duplicated entries (in the specified range) + for (int i = 0; i < keyCount; i++) { + if (keyBlock.isNull(keyOffset + i)) { + throw new IllegalArgumentException("map keys cannot be null"); + } + + long hashCode; + try { + hashCode = (long) keyBlockHashCode.invokeExact(keyBlock, keyOffset + i); + } + catch (Throwable throwable) { + if (throwable instanceof RuntimeException) { + throw (RuntimeException) throwable; + } + throw new RuntimeException(throwable); + } + + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + if (outputHashTable[hashTableOffset + hash] == -1) { + outputHashTable[hashTableOffset + hash] = i; + break; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java new file mode 100644 index 0000000000000..4bb06745fc5f7 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MapBlockEncoding.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSerde; +import io.airlift.slice.SliceInput; +import io.airlift.slice.SliceOutput; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.spi.block.AbstractMapBlock.HASH_MULTIPLIER; +import static com.facebook.presto.spi.block.MethodHandleUtil.compose; +import static com.facebook.presto.spi.block.MethodHandleUtil.nativeValueGetter; +import static io.airlift.slice.Slices.wrappedIntArray; +import static java.lang.String.format; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; + +public class MapBlockEncoding + implements BlockEncoding +{ + public static final BlockEncodingFactory FACTORY = new MapBlockEncodingFactory(); + private static final String NAME = "MAP"; + + private final Type keyType; + private final MethodHandle keyNativeHashCode; + private final MethodHandle keyBlockNativeEquals; + private final BlockEncoding keyBlockEncoding; + private final BlockEncoding valueBlockEncoding; + + public MapBlockEncoding(Type keyType, MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, BlockEncoding keyBlockEncoding, BlockEncoding valueBlockEncoding) + { + this.keyType = requireNonNull(keyType, "keyType is null"); + // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block + this.keyNativeHashCode = keyNativeHashCode; + // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block + this.keyBlockNativeEquals = keyBlockNativeEquals; + this.keyBlockEncoding = requireNonNull(keyBlockEncoding, "keyBlockEncoding is null"); + this.valueBlockEncoding = requireNonNull(valueBlockEncoding, "valueBlockEncoding is null"); + } + + @Override + public String getName() + { + return NAME; + } + + @Override + public void writeBlock(SliceOutput sliceOutput, Block block) + { + AbstractMapBlock mapBlock = (AbstractMapBlock) block; + + int positionCount = mapBlock.getPositionCount(); + + int offsetBase = mapBlock.getOffsetBase(); + int[] offsets = mapBlock.getOffsets(); + int[] hashTable = mapBlock.getHashTables(); + + int entriesStartOffset = offsets[offsetBase]; + int entriesEndOffset = offsets[offsetBase + positionCount]; + keyBlockEncoding.writeBlock(sliceOutput, mapBlock.getKeys().getRegion(entriesStartOffset, entriesEndOffset - entriesStartOffset)); + valueBlockEncoding.writeBlock(sliceOutput, mapBlock.getValues().getRegion(entriesStartOffset, entriesEndOffset - entriesStartOffset)); + + sliceOutput.appendInt((entriesEndOffset - entriesStartOffset) * HASH_MULTIPLIER); + sliceOutput.writeBytes(wrappedIntArray(hashTable, entriesStartOffset * HASH_MULTIPLIER, (entriesEndOffset - entriesStartOffset) * HASH_MULTIPLIER)); + + sliceOutput.appendInt(positionCount); + for (int position = 0; position < positionCount + 1; position++) { + sliceOutput.writeInt(offsets[offsetBase + position] - entriesStartOffset); + } + EncoderUtil.encodeNullsAsBits(sliceOutput, block); + } + + @Override + public Block readBlock(SliceInput sliceInput) + { + Block keyBlock = keyBlockEncoding.readBlock(sliceInput); + Block valueBlock = valueBlockEncoding.readBlock(sliceInput); + + int[] hashTable = new int[sliceInput.readInt()]; + sliceInput.readBytes(wrappedIntArray(hashTable)); + + if (keyBlock.getPositionCount() != valueBlock.getPositionCount() || keyBlock.getPositionCount() * HASH_MULTIPLIER != hashTable.length) { + throw new IllegalArgumentException( + format("Deserialized MapBlock violates invariants: key %d, value %d, hash %d", keyBlock.getPositionCount(), valueBlock.getPositionCount(), hashTable.length)); + } + + int positionCount = sliceInput.readInt(); + int[] offsets = new int[positionCount + 1]; + sliceInput.readBytes(wrappedIntArray(offsets)); + boolean[] mapIsNull = EncoderUtil.decodeNullBits(sliceInput, positionCount); + return new MapBlock(0, positionCount, mapIsNull, offsets, keyBlock, valueBlock, hashTable, keyType, keyBlockNativeEquals, keyNativeHashCode); + } + + @Override + public BlockEncodingFactory getFactory() + { + return FACTORY; + } + + public static class MapBlockEncodingFactory + implements BlockEncodingFactory + { + @Override + public String getName() + { + return NAME; + } + + @Override + public MapBlockEncoding readEncoding(TypeManager typeManager, BlockEncodingSerde serde, SliceInput input) + { + Type keyType = TypeSerde.readType(typeManager, input); + MethodHandle keyNativeEquals = typeManager.resolveOperator(OperatorType.EQUAL, asList(keyType, keyType)); + MethodHandle keyBlockNativeEquals = compose(keyNativeEquals, nativeValueGetter(keyType)); + MethodHandle keyNativeHashCode = typeManager.resolveOperator(OperatorType.HASH_CODE, singletonList(keyType)); + + BlockEncoding keyBlockEncoding = serde.readBlockEncoding(input); + BlockEncoding valueBlockEncoding = serde.readBlockEncoding(input); + return new MapBlockEncoding(keyType, keyBlockNativeEquals, keyNativeHashCode, keyBlockEncoding, valueBlockEncoding); + } + + @Override + public void writeEncoding(BlockEncodingSerde serde, SliceOutput output, MapBlockEncoding blockEncoding) + { + TypeSerde.writeType(output, blockEncoding.keyType); + serde.writeBlockEncoding(output, blockEncoding.keyBlockEncoding); + serde.writeBlockEncoding(output, blockEncoding.valueBlockEncoding); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/MethodHandleUtil.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/MethodHandleUtil.java new file mode 100644 index 0000000000000..aaa0fc23b7eaf --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/MethodHandleUtil.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; + +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.lang.invoke.MethodType.methodType; + +public final class MethodHandleUtil +{ + private static final MethodHandle GET_LONG = methodHandle(Type.class, "getLong", Block.class, int.class); + private static final MethodHandle GET_DOUBLE = methodHandle(Type.class, "getDouble", Block.class, int.class); + private static final MethodHandle GET_BOOLEAN = methodHandle(Type.class, "getBoolean", Block.class, int.class); + private static final MethodHandle GET_SLICE = methodHandle(Type.class, "getSlice", Block.class, int.class); + private static final MethodHandle GET_BLOCK = methodHandle(Type.class, "getObject", Block.class, int.class).asType(methodType(Block.class, Type.class, Block.class, int.class)); + private static final MethodHandle GET_UNKNOWN = methodHandle(MethodHandleUtil.class, "unknownGetter", Type.class, Block.class, int.class); + + private MethodHandleUtil() + { + } + + /** + * @param f (U, S1, S2, ..., Sm)R + * @param g (T1, T2, ..., Tn)U + * @return (T1, T2, ..., Tn, S1, S2, ..., Sm)R + */ + public static MethodHandle compose(MethodHandle f, MethodHandle g) + { + if (f.type().parameterType(0) != g.type().returnType()) { + throw new IllegalArgumentException(String.format("f.parameter(0) != g.return(). f: %s g: %s", f.type(), g.type())); + } + // Semantics: f => f + // Type: (U, S1, S2, ..., Sn)R => (U, T1, T2, ..., Tm, S1, S2, ..., Sn)R + MethodHandle fUTS = MethodHandles.dropArguments(f, 1, g.type().parameterList()); + // Semantics: f => fg + // Type: (U, T1, T2, ..., Tm, S1, S2, ..., Sn)R => (T1, T2, ..., Tm, S1, S2, ..., Sn)R + return MethodHandles.foldArguments(fUTS, g); + } + + /** + * @param f (U, V)R + * @param g (S1, S2, ..., Sm)U + * @param h (T1, T2, ..., Tn)V + * @return (S1, S2, ..., Sm, T1, T2, ..., Tn)R + */ + public static MethodHandle compose(MethodHandle f, MethodHandle g, MethodHandle h) + { + if (f.type().parameterCount() != 2) { + throw new IllegalArgumentException(String.format("f.parameterCount != 2. f: %s", f.type())); + } + if (f.type().parameterType(0) != g.type().returnType()) { + throw new IllegalArgumentException(String.format("f.parameter(0) != g.return. f: %s g: %s", f.type(), g.type())); + } + if (f.type().parameterType(1) != h.type().returnType()) { + throw new IllegalArgumentException(String.format("f.parameter(0) != h.return. f: %s h: %s", f.type(), h.type())); + } + + // (V, T1, T2, ..., Tn, U)R + MethodType typeVTU = f.type().dropParameterTypes(0, 1).appendParameterTypes(h.type().parameterList()).appendParameterTypes(f.type().parameterType(0)); + // Semantics: f => f + // Type: (U, V)R => (V, T1, T2, ..., Tn, U)R + MethodHandle fVTU = MethodHandles.permuteArguments(f, typeVTU, h.type().parameterCount() + 1, 0); + // Semantics: f => fh + // Type: (V, T1, T2, ..., Tn, U)R => (T1, T2, ..., Tn, U)R + MethodHandle fhTU = MethodHandles.foldArguments(fVTU, h); + + // reorder: [m+1, m+2, ..., m+n, 0] + int[] reorder = new int[fhTU.type().parameterCount()]; + for (int i = 0; i < reorder.length - 1; i++) { + reorder[i] = i + 1 + g.type().parameterCount(); + } + reorder[reorder.length - 1] = 0; + + // (U, S1, S2, ..., Sm, T1, T2, ..., Tn)R + MethodType typeUST = f.type().dropParameterTypes(1, 2).appendParameterTypes(g.type().parameterList()).appendParameterTypes(h.type().parameterList()); + // Semantics: f.h => f.h + // Type: (T1, T2, ..., Tn, U)R => (U, S1, S2, ..., Sm, T1, T2, ..., Tn)R + MethodHandle fhUST = MethodHandles.permuteArguments(fhTU, typeUST, reorder); + + // Semantics: fh => fgh + // Type: (U, S1, S2, ..., Sm, T1, T2, ..., Tn)R => (S1, S2, ..., Sm, T1, T2, ..., Tn)R + return MethodHandles.foldArguments(fhUST, g); + } + + /** + * Returns a MethodHandle corresponding to the specified method. + *

    + * Warning: The way Oracle JVM implements producing MethodHandle for a method involves creating + * JNI global weak references. G1 processes such references serially. As a result, calling this + * method in a tight loop can create significant GC pressure and significantly increase + * application pause time. + */ + public static MethodHandle methodHandle(Class clazz, String name, Class... parameterTypes) + { + try { + return MethodHandles.lookup().unreflect(clazz.getMethod(name, parameterTypes)); + } + catch (IllegalAccessException | NoSuchMethodException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, e); + } + } + + public static MethodHandle nativeValueGetter(Type type) + { + Class javaType = type.getJavaType(); + + MethodHandle methodHandle; + if (javaType == long.class) { + methodHandle = GET_LONG; + } + else if (javaType == double.class) { + methodHandle = GET_DOUBLE; + } + else if (javaType == boolean.class) { + methodHandle = GET_BOOLEAN; + } + else if (javaType == Slice.class) { + methodHandle = GET_SLICE; + } + else if (javaType == Block.class) { + methodHandle = GET_BLOCK; + } + else if (javaType == void.class) { + methodHandle = GET_UNKNOWN; + } + else { + throw new IllegalArgumentException("Unknown java type " + javaType + " from type " + type); + } + + return methodHandle.bindTo(type); + } + + public static Void unknownGetter(Type type, Block block, int position) + { + throw new IllegalArgumentException("For UNKNOWN type, getter should never be invoked on Block"); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java index c1f2065982558..87b4d9a9e59ca 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/RunLengthEncodedBlock.java @@ -19,6 +19,7 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; import static java.lang.String.format; @@ -70,17 +71,24 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return value.getSizeInBytes(); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + value.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(value, value.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public RunLengthBlockEncoding getEncoding() { @@ -102,7 +110,7 @@ public Block getRegion(int positionOffset, int length) } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { return value.getSizeInBytes(); } @@ -177,7 +185,7 @@ public void writeBytesTo(int position, int offset, int length, BlockBuilder bloc @Override public void writePositionTo(int position, BlockBuilder blockBuilder) { - value.writePositionTo(position, blockBuilder); + value.writePositionTo(0, blockBuilder); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java index 62c8b3676b1f7..1a5aa29fe1d12 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlock.java @@ -17,9 +17,9 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class ShortArrayBlock @@ -32,8 +32,8 @@ public class ShortArrayBlock private final boolean[] valueIsNull; private final short[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public ShortArrayBlock(int positionCount, boolean[] valueIsNull, short[] values) { @@ -61,28 +61,36 @@ public ShortArrayBlock(int positionCount, boolean[] valueIsNull, short[] values) } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) positionCount); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + sizeInBytes = (Short.BYTES + Byte.BYTES) * (long) positionCount; + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) length); + return (Short.BYTES + Byte.BYTES) * (long) length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java index 78effb3f8e588..58d8aacda82f4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ShortArrayBlockBuilder.java @@ -15,21 +15,23 @@ import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; import static java.lang.Math.max; -import static java.util.Objects.requireNonNull; public class ShortArrayBlockBuilder implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(ShortArrayBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ShortArrayBlockBuilder.class).instanceSize(); + @Nullable private BlockBuilderStatus blockBuilderStatus; private boolean initialized; private int initialEntryCount; @@ -40,11 +42,11 @@ public class ShortArrayBlockBuilder private boolean[] valueIsNull = new boolean[0]; private short[] values = new short[0]; - private int retainedSizeInBytes; + private long retainedSizeInBytes; - public ShortArrayBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries) + public ShortArrayBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries) { - this.blockBuilderStatus = requireNonNull(blockBuilderStatus, "blockBuilderStatus is null"); + this.blockBuilderStatus = blockBuilderStatus; this.initialEntryCount = max(expectedEntries, 1); updateDataSize(); @@ -60,7 +62,9 @@ public BlockBuilder writeShort(int value) values[positionCount] = (short) value; positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + Short.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + Short.BYTES); + } return this; } @@ -80,7 +84,9 @@ public BlockBuilder appendNull() valueIsNull[positionCount] = true; positionCount++; - blockBuilderStatus.addBytes(Byte.BYTES + Short.BYTES); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(Byte.BYTES + Short.BYTES); + } return this; } @@ -114,28 +120,39 @@ private void growCapacity() private void updateDataSize() { - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values)); + retainedSizeInBytes = INSTANCE_SIZE + sizeOf(valueIsNull) + sizeOf(values); + if (blockBuilderStatus != null) { + retainedSizeInBytes += BlockBuilderStatus.INSTANCE_SIZE; + } } // Copied from ShortArrayBlock @Override - public int getSizeInBytes() + public long getSizeInBytes() { - return intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) positionCount); + return (Short.BYTES + Byte.BYTES) * positionCount; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast((Short.BYTES + Byte.BYTES) * (long) length); + return (Short.BYTES + Byte.BYTES) * length; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, sizeOf(values)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public int getPositionCount() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayElementBlockWriter.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java similarity index 82% rename from presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayElementBlockWriter.java rename to presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java index 180b1a46739af..e3a9fa6dcd372 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ArrayElementBlockWriter.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleArrayBlockWriter.java @@ -16,17 +16,19 @@ import io.airlift.slice.Slice; import org.openjdk.jol.info.ClassLayout; -public class ArrayElementBlockWriter - extends AbstractArrayElementBlock +import java.util.function.BiConsumer; + +public class SingleArrayBlockWriter + extends AbstractSingleArrayBlock implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(ArrayElementBlockWriter.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleArrayBlockWriter.class).instanceSize(); private final BlockBuilder blockBuilder; - private final int initialBlockBuilderSize; + private final long initialBlockBuilderSize; private int positionsWritten; - public ArrayElementBlockWriter(BlockBuilder blockBuilder, int start) + public SingleArrayBlockWriter(BlockBuilder blockBuilder, int start) { super(start); this.blockBuilder = blockBuilder; @@ -40,17 +42,24 @@ protected BlockBuilder getBlock() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return blockBuilder.getSizeInBytes() - initialBlockBuilderSize; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return INSTANCE_SIZE + blockBuilder.getRetainedSizeInBytes(); } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(blockBuilder, blockBuilder.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public BlockBuilder writeByte(int value) { @@ -141,7 +150,7 @@ public BlockBuilder newBlockBuilderLike(BlockBuilderStatus blockBuilderStatus) @Override public String toString() { - StringBuilder sb = new StringBuilder("ArrayElementBlockWriter{"); + StringBuilder sb = new StringBuilder("SingleArrayBlockWriter{"); sb.append("positionCount=").append(getPositionCount()); sb.append('}'); return sb.toString(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java new file mode 100644 index 0000000000000..16187c23b24aa --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlock.java @@ -0,0 +1,360 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; + +import java.lang.invoke.MethodHandle; +import java.util.function.BiConsumer; + +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static com.facebook.presto.spi.block.AbstractMapBlock.HASH_MULTIPLIER; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.airlift.slice.SizeOf.sizeOfIntArray; + +public class SingleMapBlock + extends AbstractSingleMapBlock +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMapBlockWriter.class).instanceSize(); + + private final int offset; + private final int positionCount; + private final Block keyBlock; + private final Block valueBlock; + private final int[] hashTable; + private final Type keyType; + private final MethodHandle keyNativeHashCode; + private final MethodHandle keyBlockNativeEquals; + + SingleMapBlock(int offset, int positionCount, Block keyBlock, Block valueBlock, int[] hashTable, Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals) + { + super(offset, keyBlock, valueBlock); + + this.offset = offset; + this.positionCount = positionCount; + this.keyBlock = keyBlock; + this.valueBlock = valueBlock; + this.hashTable = hashTable; + this.keyType = keyType; + this.keyNativeHashCode = keyNativeHashCode; + this.keyBlockNativeEquals = keyBlockNativeEquals; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + return keyBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + + valueBlock.getRegionSizeInBytes(offset / 2, positionCount / 2) + + sizeOfIntArray(positionCount / 2 * HASH_MULTIPLIER); + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + keyBlock.getRetainedSizeInBytes() + valueBlock.getRetainedSizeInBytes() + sizeOf(hashTable); + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlock, keyBlock.getRetainedSizeInBytes()); + consumer.accept(valueBlock, valueBlock.getRetainedSizeInBytes()); + consumer.accept(hashTable, sizeOf(hashTable)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + + @Override + public BlockEncoding getEncoding() + { + return new SingleMapBlockEncoding(keyType, keyNativeHashCode, keyBlockNativeEquals, keyBlock.getEncoding(), valueBlock.getEncoding()); + } + + public int getOffset() + { + return offset; + } + + Block getKeyBlock() + { + return keyBlock; + } + + Block getValueBlock() + { + return valueBlock; + } + + int[] getHashTable() + { + return hashTable; + } + + public int seekKey(Object nativeValue) + { + if (positionCount == 0) { + return -1; + } + + long hashCode; + try { + hashCode = (long) keyNativeHashCode.invoke(nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + + int hashTableOffset = offset / 2 * HASH_MULTIPLIER; + int hashTableSize = positionCount / 2 * HASH_MULTIPLIER; + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + int keyPosition = hashTable[hashTableOffset + hash]; + if (keyPosition == -1) { + return -1; + } + boolean match; + try { + match = (boolean) keyBlockNativeEquals.invoke(keyBlock, offset / 2 + keyPosition, nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + if (match) { + return keyPosition * 2 + 1; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + + // The next 5 seekKeyExact functions are the same as seekKey + // except MethodHandle.invoke is replaced with invokeExact. + + public int seekKeyExact(long nativeValue) + { + if (positionCount == 0) { + return -1; + } + + long hashCode; + try { + hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + + int hashTableOffset = offset / 2 * HASH_MULTIPLIER; + int hashTableSize = positionCount / 2 * HASH_MULTIPLIER; + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + int keyPosition = hashTable[hashTableOffset + hash]; + if (keyPosition == -1) { + return -1; + } + boolean match; + try { + match = (boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + if (match) { + return keyPosition * 2 + 1; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + + public int seekKeyExact(boolean nativeValue) + { + if (positionCount == 0) { + return -1; + } + + long hashCode; + try { + hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + + int hashTableOffset = offset / 2 * HASH_MULTIPLIER; + int hashTableSize = positionCount / 2 * HASH_MULTIPLIER; + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + int keyPosition = hashTable[hashTableOffset + hash]; + if (keyPosition == -1) { + return -1; + } + boolean match; + try { + match = (boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + if (match) { + return keyPosition * 2 + 1; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + + public int seekKeyExact(double nativeValue) + { + if (positionCount == 0) { + return -1; + } + + long hashCode; + try { + hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + + int hashTableOffset = offset / 2 * HASH_MULTIPLIER; + int hashTableSize = positionCount / 2 * HASH_MULTIPLIER; + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + int keyPosition = hashTable[hashTableOffset + hash]; + if (keyPosition == -1) { + return -1; + } + boolean match; + try { + match = (boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + if (match) { + return keyPosition * 2 + 1; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + + public int seekKeyExact(Slice nativeValue) + { + if (positionCount == 0) { + return -1; + } + + long hashCode; + try { + hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + + int hashTableOffset = offset / 2 * HASH_MULTIPLIER; + int hashTableSize = positionCount / 2 * HASH_MULTIPLIER; + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + int keyPosition = hashTable[hashTableOffset + hash]; + if (keyPosition == -1) { + return -1; + } + boolean match; + try { + match = (boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + if (match) { + return keyPosition * 2 + 1; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + + public int seekKeyExact(Block nativeValue) + { + if (positionCount == 0) { + return -1; + } + + long hashCode; + try { + hashCode = (long) keyNativeHashCode.invokeExact(nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + + int hashTableOffset = offset / 2 * HASH_MULTIPLIER; + int hashTableSize = positionCount / 2 * HASH_MULTIPLIER; + int hash = (int) Math.floorMod(hashCode, hashTableSize); + while (true) { + int keyPosition = hashTable[hashTableOffset + hash]; + if (keyPosition == -1) { + return -1; + } + boolean match; + try { + match = (boolean) keyBlockNativeEquals.invokeExact(keyBlock, offset / 2 + keyPosition, nativeValue); + } + catch (Throwable throwable) { + throw handleThrowable(throwable); + } + if (match) { + return keyPosition * 2 + 1; + } + hash++; + if (hash == hashTableSize) { + hash = 0; + } + } + } + + private static RuntimeException handleThrowable(Throwable throwable) + { + if (throwable instanceof Error) { + throw (Error) throwable; + } + if (throwable instanceof PrestoException) { + throw (PrestoException) throwable; + } + throw new PrestoException(GENERIC_INTERNAL_ERROR, throwable); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java new file mode 100644 index 0000000000000..9c5ae989953c8 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockEncoding.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSerde; +import io.airlift.slice.SliceInput; +import io.airlift.slice.SliceOutput; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.spi.block.AbstractMapBlock.HASH_MULTIPLIER; +import static com.facebook.presto.spi.block.MethodHandleUtil.compose; +import static com.facebook.presto.spi.block.MethodHandleUtil.nativeValueGetter; +import static io.airlift.slice.Slices.wrappedIntArray; +import static java.lang.String.format; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; + +public class SingleMapBlockEncoding + implements BlockEncoding +{ + public static final BlockEncodingFactory FACTORY = new SingleMapBlockEncodingFactory(); + private static final String NAME = "MAP_ELEMENT"; + + private final Type keyType; + private final MethodHandle keyNativeHashCode; + private final MethodHandle keyBlockNativeEquals; + private final BlockEncoding keyBlockEncoding; + private final BlockEncoding valueBlockEncoding; + + public SingleMapBlockEncoding(Type keyType, MethodHandle keyNativeHashCode, MethodHandle keyBlockNativeEquals, BlockEncoding keyBlockEncoding, BlockEncoding valueBlockEncoding) + { + this.keyType = requireNonNull(keyType, "keyType is null"); + // keyNativeHashCode can only be null due to map block kill switch. deprecated.new-map-block + this.keyNativeHashCode = keyNativeHashCode; + // keyBlockNativeEquals can only be null due to map block kill switch. deprecated.new-map-block + this.keyBlockNativeEquals = keyBlockNativeEquals; + this.keyBlockEncoding = requireNonNull(keyBlockEncoding, "keyBlockEncoding is null"); + this.valueBlockEncoding = requireNonNull(valueBlockEncoding, "valueBlockEncoding is null"); + } + + @Override + public String getName() + { + return NAME; + } + + @Override + public void writeBlock(SliceOutput sliceOutput, Block block) + { + SingleMapBlock singleMapBlock = (SingleMapBlock) block; + int offset = singleMapBlock.getOffset(); + int positionCount = singleMapBlock.getPositionCount(); + keyBlockEncoding.writeBlock(sliceOutput, singleMapBlock.getKeyBlock().getRegion(offset / 2, positionCount / 2)); + valueBlockEncoding.writeBlock(sliceOutput, singleMapBlock.getValueBlock().getRegion(offset / 2, positionCount / 2)); + int[] hashTable = singleMapBlock.getHashTable(); + sliceOutput.appendInt(positionCount / 2 * HASH_MULTIPLIER); + sliceOutput.writeBytes(wrappedIntArray(hashTable, offset / 2 * HASH_MULTIPLIER, positionCount / 2 * HASH_MULTIPLIER)); + } + + @Override + public Block readBlock(SliceInput sliceInput) + { + Block keyBlock = keyBlockEncoding.readBlock(sliceInput); + Block valueBlock = valueBlockEncoding.readBlock(sliceInput); + + int[] hashTable = new int[sliceInput.readInt()]; + sliceInput.readBytes(wrappedIntArray(hashTable)); + + if (keyBlock.getPositionCount() != valueBlock.getPositionCount() || keyBlock.getPositionCount() * HASH_MULTIPLIER != hashTable.length) { + throw new IllegalArgumentException( + format("Deserialized SingleMapBlock violates invariants: key %d, value %d, hash %d", keyBlock.getPositionCount(), valueBlock.getPositionCount(), hashTable.length)); + } + + return new SingleMapBlock(0, keyBlock.getPositionCount() * 2, keyBlock, valueBlock, hashTable, keyType, keyNativeHashCode, keyBlockNativeEquals); + } + + @Override + public BlockEncodingFactory getFactory() + { + return FACTORY; + } + + public static class SingleMapBlockEncodingFactory + implements BlockEncodingFactory + { + @Override + public String getName() + { + return NAME; + } + + @Override + public SingleMapBlockEncoding readEncoding(TypeManager typeManager, BlockEncodingSerde serde, SliceInput input) + { + Type keyType = TypeSerde.readType(typeManager, input); + MethodHandle keyNativeHashCode = typeManager.resolveOperator(OperatorType.HASH_CODE, singletonList(keyType)); + MethodHandle keyNativeEquals = typeManager.resolveOperator(OperatorType.EQUAL, asList(keyType, keyType)); + MethodHandle keyBlockNativeEquals = compose(keyNativeEquals, nativeValueGetter(keyType)); + + BlockEncoding keyBlockEncoding = serde.readBlockEncoding(input); + BlockEncoding valueBlockEncoding = serde.readBlockEncoding(input); + return new SingleMapBlockEncoding(keyType, keyNativeHashCode, keyBlockNativeEquals, keyBlockEncoding, valueBlockEncoding); + } + + @Override + public void writeEncoding(BlockEncodingSerde serde, SliceOutput output, SingleMapBlockEncoding blockEncoding) + { + TypeSerde.writeType(output, blockEncoding.keyType); + serde.writeBlockEncoding(output, blockEncoding.keyBlockEncoding); + serde.writeBlockEncoding(output, blockEncoding.valueBlockEncoding); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java new file mode 100644 index 0000000000000..840c068ca3e2d --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SingleMapBlockWriter.java @@ -0,0 +1,211 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.block; + +import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; + +import java.util.function.BiConsumer; + +public class SingleMapBlockWriter + extends AbstractSingleMapBlock + implements BlockBuilder +{ + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleMapBlockWriter.class).instanceSize(); + + private final BlockBuilder keyBlockBuilder; + private final BlockBuilder valueBlockBuilder; + private final long initialBlockBuilderSize; + private int positionsWritten; + + private boolean writeToValueNext; + + SingleMapBlockWriter(int start, BlockBuilder keyBlockBuilder, BlockBuilder valueBlockBuilder) + { + super(start, keyBlockBuilder, valueBlockBuilder); + this.keyBlockBuilder = keyBlockBuilder; + this.valueBlockBuilder = valueBlockBuilder; + this.initialBlockBuilderSize = keyBlockBuilder.getSizeInBytes() + valueBlockBuilder.getSizeInBytes(); + } + + @Override + public long getSizeInBytes() + { + return keyBlockBuilder.getSizeInBytes() + valueBlockBuilder.getSizeInBytes() - initialBlockBuilderSize; + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + keyBlockBuilder.getRetainedSizeInBytes() + valueBlockBuilder.getRetainedSizeInBytes(); + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(keyBlockBuilder, keyBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(valueBlockBuilder, valueBlockBuilder.getRetainedSizeInBytes()); + consumer.accept(this, (long) INSTANCE_SIZE); + } + + @Override + public BlockBuilder writeByte(int value) + { + if (writeToValueNext) { + valueBlockBuilder.writeByte(value); + } + else { + keyBlockBuilder.writeByte(value); + } + return this; + } + + @Override + public BlockBuilder writeShort(int value) + { + if (writeToValueNext) { + valueBlockBuilder.writeShort(value); + } + else { + keyBlockBuilder.writeShort(value); + } + return this; + } + + @Override + public BlockBuilder writeInt(int value) + { + if (writeToValueNext) { + valueBlockBuilder.writeInt(value); + } + else { + keyBlockBuilder.writeInt(value); + } + return this; + } + + @Override + public BlockBuilder writeLong(long value) + { + if (writeToValueNext) { + valueBlockBuilder.writeLong(value); + } + else { + keyBlockBuilder.writeLong(value); + } + return this; + } + + @Override + public BlockBuilder writeBytes(Slice source, int sourceIndex, int length) + { + if (writeToValueNext) { + valueBlockBuilder.writeBytes(source, sourceIndex, length); + } + else { + keyBlockBuilder.writeBytes(source, sourceIndex, length); + } + return this; + } + + @Override + public BlockBuilder writeObject(Object value) + { + if (writeToValueNext) { + valueBlockBuilder.writeObject(value); + } + else { + keyBlockBuilder.writeObject(value); + } + return this; + } + + @Override + public BlockBuilder beginBlockEntry() + { + BlockBuilder result; + if (writeToValueNext) { + result = valueBlockBuilder.beginBlockEntry(); + } + else { + result = keyBlockBuilder.beginBlockEntry(); + } + return result; + } + + @Override + public BlockBuilder appendNull() + { + if (writeToValueNext) { + valueBlockBuilder.appendNull(); + } + else { + keyBlockBuilder.appendNull(); + } + entryAdded(); + return this; + } + + @Override + public BlockBuilder closeEntry() + { + if (writeToValueNext) { + valueBlockBuilder.closeEntry(); + } + else { + keyBlockBuilder.closeEntry(); + } + entryAdded(); + return this; + } + + private void entryAdded() + { + writeToValueNext = !writeToValueNext; + positionsWritten++; + } + + @Override + public int getPositionCount() + { + return positionsWritten; + } + + @Override + public BlockEncoding getEncoding() + { + throw new UnsupportedOperationException(); + } + + @Override + public Block build() + { + throw new UnsupportedOperationException(); + } + + @Override + public BlockBuilder newBlockBuilderLike(BlockBuilderStatus blockBuilderStatus) + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() + { + StringBuilder sb = new StringBuilder("SingleMapBlockWriter{"); + sb.append("positionCount=").append(getPositionCount()); + sb.append('}'); + return sb.toString(); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java index 9a9198ef17e76..913511a8126b0 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/SliceArrayBlock.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.spi.block; -import io.airlift.slice.SizeOf; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.openjdk.jol.info.ClassLayout; @@ -22,21 +21,19 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; -import static sun.misc.Unsafe.ARRAY_OBJECT_INDEX_SCALE; +import static io.airlift.slice.SizeOf.sizeOf; public class SliceArrayBlock extends AbstractVariableWidthBlock { private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceArrayBlock.class).instanceSize(); - private static final int SLICE_INSTANCE_SIZE = ClassLayout.parseClass(Slice.class).instanceSize(); - private final int positionCount; private final Slice[] values; - private final int sizeInBytes; - private final int retainedSizeInBytes; + private final long sizeInBytes; + private final long retainedSizeInBytes; public SliceArrayBlock(int positionCount, Slice[] values) { @@ -52,10 +49,14 @@ public SliceArrayBlock(int positionCount, Slice[] values, boolean valueSlicesAre } this.values = values; - sizeInBytes = getSliceArraySizeInBytes(values); - - // if values are distinct, use the already computed value - retainedSizeInBytes = INSTANCE_SIZE + (valueSlicesAreDistinct ? sizeInBytes : getSliceArrayRetainedSizeInBytes(values)); + sizeInBytes = getSliceArraySizeInBytes(values, 0, values.length); + // We use IdentityHashMap for reference counting below to do proper memory accounting. + // Since IdentityHashMap uses linear probing, depending on the load factor threads going through the + // retained size calculation method will make multiple probes in the hash table, which will consume some cycles. + // We see that many threads spend cycles probing the hash table when they read the dictionary data and + // wrap that in a SliceArrayBlock (in SliceDictionaryStreamReader). Therefore, we avoid going through + // the IdentityHashMap for that code path with the valueSlicesAreDistinct flag. + retainedSizeInBytes = INSTANCE_SIZE + getSliceArrayRetainedSizeInBytes(values, valueSlicesAreDistinct); } public Slice[] getValues() @@ -114,17 +115,24 @@ public int getSliceLength(int position) } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(values, retainedSizeInBytes - INSTANCE_SIZE); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public Block getRegion(int positionOffset, int length) { @@ -177,19 +185,20 @@ public String toString() return sb.toString(); } - public static int getSliceArraySizeInBytes(Slice[] values) + private static long getSliceArraySizeInBytes(Slice[] values, int offset, int length) { - long sizeInBytes = values.length * (long) (ARRAY_OBJECT_INDEX_SCALE + SLICE_INSTANCE_SIZE); - for (Slice value : values) { + long sizeInBytes = 0; + for (int i = offset; i < offset + length; i++) { + Slice value = values[i]; if (value != null) { sizeInBytes += value.length(); } } - return intSaturatedCast(sizeInBytes); + return sizeInBytes; } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { int positionCount = getPositionCount(); if (positionOffset == 0 && length == positionCount) { @@ -201,25 +210,43 @@ public int getRegionSizeInBytes(int positionOffset, int length) throw new IndexOutOfBoundsException("Invalid position " + positionOffset + " in block with " + positionCount + " positions"); } - long sizeInBytes = length * (long) (ARRAY_OBJECT_INDEX_SCALE + SLICE_INSTANCE_SIZE); - for (int i = positionOffset; i < positionOffset + length; i++) { - Slice value = values[i]; - if (value != null) { - sizeInBytes += value.length(); - } + return getSliceArraySizeInBytes(values, positionOffset, length); + } + + private static long getSliceArrayRetainedSizeInBytes(Slice[] values, boolean valueSlicesAreDistinct) + { + if (valueSlicesAreDistinct) { + return getDistinctSliceArrayRetainedSize(values); } - return intSaturatedCast(sizeInBytes); + return getSliceArrayRetainedSizeInBytes(values); } - static int getSliceArrayRetainedSizeInBytes(Slice[] values) + // when the slices are not distinct we need to do reference counting to calculate the total retained size + private static long getSliceArrayRetainedSizeInBytes(Slice[] values) { - long sizeInBytes = SizeOf.sizeOf(values); + long sizeInBytes = sizeOf(values); Map uniqueRetained = new IdentityHashMap<>(values.length); for (Slice value : values) { - if (value != null && value.getBase() != null && uniqueRetained.put(value.getBase(), true) == null) { + if (value == null) { + continue; + } + if (value.getBase() != null && uniqueRetained.put(value.getBase(), true) == null) { sizeInBytes += value.getRetainedSize(); } } - return intSaturatedCast(sizeInBytes); + return sizeInBytes; + } + + private static long getDistinctSliceArrayRetainedSize(Slice[] values) + { + long sizeInBytes = sizeOf(values); + for (Slice value : values) { + // The check (value.getBase() == null) skips empty slices to be consistent with getSliceArrayRetainedSizeInBytes(values) + if (value == null || value.getBase() == null) { + continue; + } + sizeInBytes += value.getRetainedSize(); + } + return sizeInBytes; } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java index 01ef6db99b6ea..1b6bf8a4e11c4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlock.java @@ -20,10 +20,10 @@ import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.checkValidPositions; import static com.facebook.presto.spi.block.BlockUtil.checkValidRegion; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.sizeOf; public class VariableWidthBlock @@ -37,8 +37,8 @@ public class VariableWidthBlock private final int[] offsets; private final boolean[] valueIsNull; - private final int retainedSizeInBytes; - private final int sizeInBytes; + private final long retainedSizeInBytes; + private final long sizeInBytes; public VariableWidthBlock(int positionCount, Slice slice, int[] offsets, boolean[] valueIsNull) { @@ -71,8 +71,8 @@ public VariableWidthBlock(int positionCount, Slice slice, int[] offsets, boolean } this.valueIsNull = valueIsNull; - sizeInBytes = intSaturatedCast(offsets[arrayOffset + positionCount] - offsets[arrayOffset] + ((Integer.BYTES + Byte.BYTES) * (long) positionCount)); - retainedSizeInBytes = intSaturatedCast(INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull) + sizeOf(offsets)); + sizeInBytes = offsets[arrayOffset + positionCount] - offsets[arrayOffset] + ((Integer.BYTES + Byte.BYTES) * (long) positionCount); + retainedSizeInBytes = INSTANCE_SIZE + slice.getRetainedSize() + sizeOf(valueIsNull) + sizeOf(offsets); } @Override @@ -101,23 +101,32 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { return sizeInBytes; } @Override - public int getRegionSizeInBytes(int position, int length) + public long getRegionSizeInBytes(int position, int length) { - return intSaturatedCast(offsets[arrayOffset + position + length] - offsets[arrayOffset + position] + ((Integer.BYTES + Byte.BYTES) * (long) length)); + return offsets[arrayOffset + position + length] - offsets[arrayOffset + position] + ((Integer.BYTES + Byte.BYTES) * (long) length); } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { return retainedSizeInBytes; } + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(slice, (long) slice.getRetainedSize()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); + } + @Override public Block copyPositions(List positions) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java index dad8956304c0a..27fdea423d0ab 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/VariableWidthBlockBuilder.java @@ -19,24 +19,25 @@ import io.airlift.slice.Slices; import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; +import java.util.function.BiConsumer; import static com.facebook.presto.spi.block.BlockUtil.MAX_ARRAY_SIZE; import static com.facebook.presto.spi.block.BlockUtil.calculateBlockResetSize; -import static com.facebook.presto.spi.block.BlockUtil.intSaturatedCast; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.airlift.slice.SizeOf.SIZE_OF_INT; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.slice.SizeOf.SIZE_OF_SHORT; import static io.airlift.slice.SizeOf.sizeOf; -import static java.util.Objects.requireNonNull; public class VariableWidthBlockBuilder extends AbstractVariableWidthBlock implements BlockBuilder { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(VariableWidthBlockBuilder.class).instanceSize() + BlockBuilderStatus.INSTANCE_SIZE; + private static final int INSTANCE_SIZE = ClassLayout.parseClass(VariableWidthBlockBuilder.class).instanceSize(); private BlockBuilderStatus blockBuilderStatus; @@ -55,9 +56,9 @@ public class VariableWidthBlockBuilder private long arraysRetainedSizeInBytes; - public VariableWidthBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) + public VariableWidthBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { - this.blockBuilderStatus = requireNonNull(blockBuilderStatus, "blockBuilderStatus is null"); + this.blockBuilderStatus = blockBuilderStatus; initialEntryCount = expectedEntries; initialSliceOutputSize = (int) Math.min((long) expectedBytesPerEntry * expectedEntries, MAX_ARRAY_SIZE); @@ -96,27 +97,40 @@ public int getPositionCount() } @Override - public int getSizeInBytes() + public long getSizeInBytes() { - long arraysSizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) positions; - return intSaturatedCast(sliceOutput.size() + arraysSizeInBytes); + long arraysSizeInBytes = (Integer.BYTES + Byte.BYTES) * positions; + return sliceOutput.size() + arraysSizeInBytes; } @Override - public int getRegionSizeInBytes(int positionOffset, int length) + public long getRegionSizeInBytes(int positionOffset, int length) { int positionCount = getPositionCount(); if (positionOffset < 0 || length < 0 || positionOffset + length > positionCount) { throw new IndexOutOfBoundsException("Invalid position " + positionOffset + " length " + length + " in block with " + positionCount + " positions"); } - long arraysSizeInBytes = (Integer.BYTES + Byte.BYTES) * (long) length; - return intSaturatedCast(getOffset(positionOffset + length) - getOffset(positionOffset) + arraysSizeInBytes); + long arraysSizeInBytes = (Integer.BYTES + Byte.BYTES) * length; + return getOffset(positionOffset + length) - getOffset(positionOffset) + arraysSizeInBytes; } @Override - public int getRetainedSizeInBytes() + public long getRetainedSizeInBytes() { - return intSaturatedCast(INSTANCE_SIZE + sliceOutput.getRetainedSize() + arraysRetainedSizeInBytes); + long size = INSTANCE_SIZE + sliceOutput.getRetainedSize() + arraysRetainedSizeInBytes; + if (blockBuilderStatus != null) { + size += BlockBuilderStatus.INSTANCE_SIZE; + } + return size; + } + + @Override + public void retainedBytesForEachPart(BiConsumer consumer) + { + consumer.accept(sliceOutput, (long) sliceOutput.getRetainedSize()); + consumer.accept(offsets, sizeOf(offsets)); + consumer.accept(valueIsNull, sizeOf(valueIsNull)); + consumer.accept(this, (long) INSTANCE_SIZE); } @Override @@ -228,7 +242,9 @@ private void entryAdded(int bytesWritten, boolean isNull) positions++; - blockBuilderStatus.addBytes(SIZE_OF_BYTE + SIZE_OF_INT + bytesWritten); + if (blockBuilderStatus != null) { + blockBuilderStatus.addBytes(SIZE_OF_BYTE + SIZE_OF_INT + bytesWritten); + } } private void growCapacity() @@ -253,7 +269,7 @@ private void initializeCapacity() private void updateArraysDataSize() { - arraysRetainedSizeInBytes = intSaturatedCast(sizeOf(valueIsNull) + sizeOf(offsets)); + arraysRetainedSizeInBytes = sizeOf(valueIsNull) + sizeOf(offsets); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java index 539d023c84838..4f572f536b4af 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorAccessControl.java @@ -262,7 +262,7 @@ default void checkCanSetCatalogSessionProperty(Identity identity, String propert * * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ - default void checkCanGrantTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName) + default void checkCanGrantTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName, String grantee, boolean withGrantOption) { denyGrantTablePrivilege(privilege.toString(), tableName.toString()); } @@ -272,7 +272,7 @@ default void checkCanGrantTablePrivilege(ConnectorTransactionHandle transactionH * * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ - default void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName) + default void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transactionHandle, Identity identity, Privilege privilege, SchemaTableName tableName, String revokee, boolean grantOptionFor) { denyRevokeTablePrivilege(privilege.toString(), tableName.toString()); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java index f9c5774263555..56ce4b5a0ecdd 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java @@ -35,6 +35,7 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.statistics.TableStatistics; import io.airlift.slice.Slice; import java.util.Collection; @@ -47,6 +48,7 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.statistics.TableStatistics.EMPTY_STATISTICS; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.stream.Collectors.toList; @@ -126,6 +128,14 @@ default Optional getInfo(ConnectorTableLayoutHandle layoutHandle) */ Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix); + /** + * Get statistics for table for given filtering constraint. + */ + default TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint) + { + return EMPTY_STATISTICS; + } + /** * Creates a schema. */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java index 37b87d4848445..974f1c9c644bd 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java @@ -37,6 +37,7 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.statistics.TableStatistics; import io.airlift.slice.Slice; import java.util.Collection; @@ -168,6 +169,14 @@ public Map> listTableColumns(ConnectorSess } } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getTableStatistics(session, tableHandle, constraint); + } + } + @Override public void addColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnMetadata column) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryContext.java index e07fce2baa69e..08602900efa18 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryContext.java @@ -32,6 +32,8 @@ public class QueryContext private final Optional catalog; private final Optional schema; + private final Optional resourceGroupName; + private final Map sessionProperties; private final String serverAddress; @@ -47,6 +49,7 @@ public QueryContext( Optional source, Optional catalog, Optional schema, + Optional resourceGroupName, Map sessionProperties, String serverAddress, String serverVersion, @@ -60,6 +63,7 @@ public QueryContext( this.source = requireNonNull(source, "source is null"); this.catalog = requireNonNull(catalog, "catalog is null"); this.schema = requireNonNull(schema, "schema is null"); + this.resourceGroupName = requireNonNull(resourceGroupName, "resourceGroupName is null"); this.sessionProperties = requireNonNull(sessionProperties, "sessionProperties is null"); this.serverAddress = requireNonNull(serverAddress, "serverAddress is null"); this.serverVersion = requireNonNull(serverVersion, "serverVersion is null"); @@ -114,6 +118,12 @@ public Optional getSchema() return schema; } + @JsonProperty + public Optional getResourceGroupName() + { + return resourceGroupName; + } + @JsonProperty public Map getSessionProperties() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryStatistics.java index d9e34a0535be9..2d4b6b35a591f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryStatistics.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/QueryStatistics.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.eventlistener; import java.time.Duration; +import java.util.List; import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -30,10 +31,14 @@ public class QueryStatistics private final long totalBytes; private final long totalRows; + private final double cumulativeMemory; + private final int completedSplits; private final boolean complete; - private final String operatorSummaries; + private final List cpuTimeDistribution; + + private final List operatorSummaries; public QueryStatistics( Duration cpuTime, @@ -44,9 +49,11 @@ public QueryStatistics( long peakMemoryBytes, long totalBytes, long totalRows, + double cumulativeMemory, int completedSplits, boolean complete, - String operatorSummaries) + List cpuTimeDistribution, + List operatorSummaries) { this.cpuTime = requireNonNull(cpuTime, "cpuTime is null"); this.wallTime = requireNonNull(wallTime, "wallTime is null"); @@ -56,8 +63,10 @@ public QueryStatistics( this.peakMemoryBytes = requireNonNull(peakMemoryBytes, "peakMemoryBytes is null"); this.totalBytes = requireNonNull(totalBytes, "totalBytes is null"); this.totalRows = requireNonNull(totalRows, "totalRows is null"); + this.cumulativeMemory = cumulativeMemory; this.completedSplits = requireNonNull(completedSplits, "completedSplits is null"); this.complete = complete; + this.cpuTimeDistribution = requireNonNull(cpuTimeDistribution, "cpuTimeDistribution is null"); this.operatorSummaries = requireNonNull(operatorSummaries, "operatorSummaries is null"); } @@ -101,6 +110,11 @@ public long getTotalRows() return totalRows; } + public double getCumulativeMemory() + { + return cumulativeMemory; + } + public int getCompletedSplits() { return completedSplits; @@ -111,7 +125,12 @@ public boolean isComplete() return complete; } - public String getOperatorSummaries() + public List getCpuTimeDistribution() + { + return cpuTimeDistribution; + } + + public List getOperatorSummaries() { return operatorSummaries; } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/StageCpuDistribution.java b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/StageCpuDistribution.java new file mode 100644 index 0000000000000..f3c20a84f7d79 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/eventlistener/StageCpuDistribution.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.eventlistener; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +public class StageCpuDistribution +{ + private final int stageId; + private final int tasks; + private final long p25; + private final long p50; + private final long p75; + private final long p90; + private final long p95; + private final long p99; + private final long min; + private final long max; + private final long total; + private final double average; + + @JsonCreator + public StageCpuDistribution( + @JsonProperty("stageId") int stageId, + @JsonProperty("tasks") int tasks, + @JsonProperty("p25") long p25, + @JsonProperty("p50") long p50, + @JsonProperty("p75") long p75, + @JsonProperty("p90") long p90, + @JsonProperty("p95") long p95, + @JsonProperty("p99") long p99, + @JsonProperty("min") long min, + @JsonProperty("max") long max, + @JsonProperty("total") long total, + @JsonProperty("average") double average) + { + this.stageId = stageId; + this.tasks = tasks; + this.p25 = p25; + this.p50 = p50; + this.p75 = p75; + this.p90 = p90; + this.p95 = p95; + this.p99 = p99; + this.min = min; + this.max = max; + this.total = total; + this.average = average; + } + + @JsonProperty + public int getStageId() + { + return stageId; + } + + @JsonProperty + public int getTasks() + { + return tasks; + } + + @JsonProperty + public long getP25() + { + return p25; + } + + @JsonProperty + public long getP50() + { + return p50; + } + + @JsonProperty + public long getP75() + { + return p75; + } + + @JsonProperty + public long getP90() + { + return p90; + } + + @JsonProperty + public long getP95() + { + return p95; + } + + @JsonProperty + public long getP99() + { + return p99; + } + + @JsonProperty + public long getMin() + { + return min; + } + + @JsonProperty + public long getMax() + { + return max; + } + + @JsonProperty + public long getTotal() + { + return total; + } + + @JsonProperty + public double getAverage() + { + return average; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroup.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroup.java index 13d35d7226950..86a851a92567c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroup.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroup.java @@ -86,4 +86,12 @@ public interface ResourceGroup * Whether to export statistics about this group and allow configuration via JMX. */ void setJmxExport(boolean export); + + Duration getQueuedTimeLimit(); + + void setQueuedTimeLimit(Duration queuedTimeLimit); + + Duration getRunningTimeLimit(); + + void setRunningTimeLimit(Duration runningTimeLimit); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupConfigurationManagerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupConfigurationManagerContext.java index 8a64e02c0a62a..6ddde1677e00e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupConfigurationManagerContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupConfigurationManagerContext.java @@ -17,5 +17,13 @@ public interface ResourceGroupConfigurationManagerContext { - ClusterMemoryPoolManager getMemoryPoolManager(); + default ClusterMemoryPoolManager getMemoryPoolManager() + { + throw new UnsupportedOperationException(); + } + + default String getEnvironment() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupId.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupId.java index 0f97bd00217b7..da1500ccda99c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupId.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupId.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spi.resourceGroups; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; import java.util.ArrayList; @@ -46,7 +47,8 @@ private static List append(List list, String element) return result; } - private ResourceGroupId(List segments) + @JsonCreator + public ResourceGroupId(List segments) { checkArgument(!segments.isEmpty(), "Resource group id is empty"); for (String segment : segments) { @@ -60,6 +62,7 @@ public String getLastSegment() return segments.get(segments.size() - 1); } + @JsonValue public List getSegments() { return segments; @@ -95,7 +98,6 @@ private static void checkArgument(boolean argument, String format, Object... arg } @Override - @JsonValue public String toString() { return segments.stream() diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java index abe8e6f245c55..e30d5393ada7e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/resourceGroups/ResourceGroupInfo.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import io.airlift.units.DataSize; +import io.airlift.units.Duration; import java.util.List; import java.util.Optional; @@ -30,7 +31,9 @@ public class ResourceGroupInfo private final DataSize softMemoryLimit; private final int maxRunningQueries; + private final Duration runningTimeLimit; private final int maxQueuedQueries; + private final Duration queuedTimeLimit; private final ResourceGroupState state; private final int numEligibleSubGroups; @@ -45,21 +48,25 @@ public ResourceGroupInfo( @JsonProperty("id") ResourceGroupId id, @JsonProperty("softMemoryLimit") DataSize softMemoryLimit, @JsonProperty("maxRunningQueries") int maxRunningQueries, + @JsonProperty("runningTimeLimit") Duration runningTimeLimit, @JsonProperty("maxQueuedQueries") int maxQueuedQueries, + @JsonProperty("queuedTimeLimit") Duration queuedTimeLimit, @JsonProperty("state") ResourceGroupState state, @JsonProperty("numEligibleSubGroups") int numEligibleSubGroups, @JsonProperty("memoryUsage") DataSize memoryUsage, @JsonProperty("numAggregatedRunningQueries") int numAggregatedRunningQueries, @JsonProperty("numAggregatedQueuedQueries") int numAggregatedQueuedQueries) { - this(id, softMemoryLimit, maxRunningQueries, maxQueuedQueries, state, numEligibleSubGroups, memoryUsage, numAggregatedRunningQueries, numAggregatedQueuedQueries, emptyList()); + this(id, softMemoryLimit, maxRunningQueries, runningTimeLimit, maxQueuedQueries, queuedTimeLimit, state, numEligibleSubGroups, memoryUsage, numAggregatedRunningQueries, numAggregatedQueuedQueries, emptyList()); } public ResourceGroupInfo( ResourceGroupId id, DataSize softMemoryLimit, int maxRunningQueries, + Duration runningTimeLimit, int maxQueuedQueries, + Duration queuedTimeLimit, ResourceGroupState state, int numEligibleSubGroups, DataSize memoryUsage, @@ -70,7 +77,9 @@ public ResourceGroupInfo( this.id = requireNonNull(id, "id is null"); this.softMemoryLimit = requireNonNull(softMemoryLimit, "softMemoryLimit is null"); this.maxRunningQueries = maxRunningQueries; + this.runningTimeLimit = runningTimeLimit; this.maxQueuedQueries = maxQueuedQueries; + this.queuedTimeLimit = queuedTimeLimit; this.state = requireNonNull(state, "state is null"); this.numEligibleSubGroups = numEligibleSubGroups; this.memoryUsage = requireNonNull(memoryUsage, "memoryUsage is null"); @@ -97,12 +106,24 @@ public int getMaxRunningQueries() return maxRunningQueries; } + @JsonProperty + public Duration getRunningTimeLimit() + { + return runningTimeLimit; + } + @JsonProperty public int getMaxQueuedQueries() { return maxQueuedQueries; } + @JsonProperty + public Duration getQueuedTimeLimit() + { + return queuedTimeLimit; + } + public List getSubGroups() { return subGroups; @@ -151,7 +172,9 @@ public ResourceGroupInfo createSingleNodeInfo() getId(), getSoftMemoryLimit(), getMaxRunningQueries(), + getRunningTimeLimit(), getMaxQueuedQueries(), + getQueuedTimeLimit(), getState(), getNumEligibleSubGroups(), getMemoryUsage(), diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java index d7faa5b37d790..81cb34dd356ba 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/AccessDeniedException.java @@ -38,6 +38,16 @@ public static void denySetUser(Principal principal, String userName, String extr throw new AccessDeniedException(format("Principal %s cannot become user %s%s", principal, userName, formatExtraInfo(extraInfo))); } + public static void denyCatalogAccess(String catalogName) + { + denyCatalogAccess(catalogName, null); + } + + public static void denyCatalogAccess(String catalogName, String extraInfo) + { + throw new AccessDeniedException(format("Cannot access catalog %s%s", catalogName, formatExtraInfo(extraInfo))); + } + public static void denyCreateSchema(String schemaName) { denyCreateSchema(schemaName, null); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java b/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java index 09133d99b103f..7226aa7c911d6 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/security/SystemAccessControl.java @@ -22,6 +22,7 @@ import java.util.Set; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCatalogAccess; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; @@ -58,6 +59,16 @@ public interface SystemAccessControl */ void checkCanSetSystemSessionProperty(Identity identity, String propertyName); + /** + * Check if identity is allowed to access the specified catalog + * + * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed + */ + default void checkCanAccessCatalog(Identity identity, String catalogName) + { + denyCatalogAccess(catalogName); + } + /** * Filter the list of catalogs to those visible to the identity. */ @@ -281,21 +292,21 @@ default void checkCanSetCatalogSessionProperty(Identity identity, String catalog } /** - * Check if identity is allowed to grant to any other user the specified privilege on the specified table. + * Check if identity is allowed to grant the specified privilege to the grantee on the specified table. * * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ - default void checkCanGrantTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table) + default void checkCanGrantTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table, String grantee, boolean withGrantOption) { denyGrantTablePrivilege(privilege.toString(), table.toString()); } /** - * Check if identity is allowed to revoke the specified privilege on the specified table from any user. + * Check if identity is allowed to revoke the specified privilege on the specified table from the revokee. * * @throws com.facebook.presto.spi.security.AccessDeniedException if not allowed */ - default void checkCanRevokeTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table) + default void checkCanRevokeTablePrivilege(Identity identity, Privilege privilege, CatalogSchemaTableName table, String revokee, boolean grantOptionFor) { denyRevokeTablePrivilege(privilege.toString(), table.toString()); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java new file mode 100644 index 0000000000000..113ec2e264815 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.statistics; + +import java.util.HashMap; +import java.util.Map; + +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; + +public final class ColumnStatistics +{ + private final Map statistics; + private static final String DATA_SIZE_STATISTIC_KEY = "data_size"; + private static final String NULLS_COUNT_STATISTIC_KEY = "nulls_count"; + private static final String DISTINCT_VALUES_STATITIC_KEY = "distinct_values_count"; + + private ColumnStatistics(Estimate dataSize, Estimate nullsCount, Estimate distinctValuesCount) + { + requireNonNull(dataSize, "dataSize can not be null"); + statistics = createStatisticsMap(dataSize, nullsCount, distinctValuesCount); + } + + private static Map createStatisticsMap(Estimate dataSize, Estimate nullsCount, Estimate distinctValuesCount) + { + Map statistics = new HashMap<>(); + statistics.put(DATA_SIZE_STATISTIC_KEY, dataSize); + statistics.put(NULLS_COUNT_STATISTIC_KEY, nullsCount); + statistics.put(DISTINCT_VALUES_STATITIC_KEY, distinctValuesCount); + return unmodifiableMap(statistics); + } + + public Estimate getDataSize() + { + return statistics.get(DATA_SIZE_STATISTIC_KEY); + } + + public Estimate getNullsCount() + { + return statistics.get(NULLS_COUNT_STATISTIC_KEY); + } + + public Estimate getDistinctValuesCount() + { + return statistics.get(DISTINCT_VALUES_STATITIC_KEY); + } + + public Map getStatistics() + { + return statistics; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Estimate dataSize = unknownValue(); + private Estimate nullsCount = unknownValue(); + private Estimate distinctValuesCount = unknownValue(); + + public Builder setDataSize(Estimate dataSize) + { + this.dataSize = requireNonNull(dataSize, "dataSize can not be null"); + return this; + } + + public Builder setNullsCount(Estimate nullsCount) + { + this.nullsCount = nullsCount; + return this; + } + + public Builder setDistinctValuesCount(Estimate distinctValuesCount) + { + this.distinctValuesCount = distinctValuesCount; + return this; + } + + public ColumnStatistics build() + { + return new ColumnStatistics(dataSize, nullsCount, distinctValuesCount); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java new file mode 100644 index 0000000000000..27a490699fde0 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/Estimate.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.statistics; + +import java.util.Objects; +import java.util.function.Function; + +import static java.lang.Double.isNaN; + +public final class Estimate +{ + // todo eventually add some notion of statistic reliability + // Skipping for now as there hard to compute it properly and so far we do not have + // usecase for that. + + private static final double UNKNOWN_VALUE = Double.NaN; + + private final double value; + + public static final Estimate unknownValue() + { + return new Estimate(UNKNOWN_VALUE); + } + + public Estimate(double value) + { + this.value = value; + } + + public boolean isValueUnknown() + { + return isNaN(value); + } + + public double getValue() + { + return value; + } + + public Estimate map(Function mappingFunction) + { + if (isValueUnknown()) { + return this; + } + else { + return new Estimate(mappingFunction.apply(value)); + } + } + + @Override + public String toString() + { + return String.valueOf(value); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Estimate estimate = (Estimate) o; + return Double.compare(estimate.value, value) == 0; + } + + @Override + public int hashCode() + { + return Objects.hash(value); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/TableStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/TableStatistics.java new file mode 100644 index 0000000000000..9cef10d064543 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/TableStatistics.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.statistics; + +import com.facebook.presto.spi.ColumnHandle; + +import java.util.HashMap; +import java.util.Map; + +import static com.facebook.presto.spi.statistics.Estimate.unknownValue; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; + +public final class TableStatistics +{ + public static final TableStatistics EMPTY_STATISTICS = TableStatistics.builder().build(); + + private static final String ROW_COUNT_STATISTIC_KEY = "row_count"; + + private final Map statistics; + private final Map columnStatistics; + + public TableStatistics(Estimate rowCount, Map columnStatistics) + { + requireNonNull(rowCount, "rowCount can not be null"); + this.columnStatistics = unmodifiableMap(requireNonNull(columnStatistics, "columnStatistics can not be null")); + this.statistics = createStatisticsMap(rowCount); + } + + private static Map createStatisticsMap(Estimate rowCount) + { + Map statistics = new HashMap<>(); + statistics.put(ROW_COUNT_STATISTIC_KEY, rowCount); + return unmodifiableMap(statistics); + } + + public Estimate getRowCount() + { + return statistics.get(ROW_COUNT_STATISTIC_KEY); + } + + public Map getTableStatistics() + { + return statistics; + } + + public Map getColumnStatistics() + { + return columnStatistics; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Estimate rowCount = unknownValue(); + private Map columnStatisticsMap = new HashMap<>(); + + public Builder setRowCount(Estimate rowCount) + { + this.rowCount = requireNonNull(rowCount, "rowCount can not be null"); + return this; + } + + public Builder setColumnStatistics(ColumnHandle columnName, ColumnStatistics columnStatistics) + { + requireNonNull(columnName, "columnName can not be null"); + requireNonNull(columnStatistics, "columnStatistics can not be null"); + this.columnStatisticsMap.put(columnName, columnStatistics); + return this; + } + + public TableStatistics build() + { + return new TableStatistics(rowCount, columnStatisticsMap); + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractFixedWidthType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractFixedWidthType.java index c218e3748628f..01f9b56df912a 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractFixedWidthType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractFixedWidthType.java @@ -40,10 +40,17 @@ public final int getFixedSize() @Override public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new FixedWidthBlockBuilder( getFixedSize(), blockBuilderStatus, - fixedSize == 0 ? expectedEntries : Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / fixedSize)); + fixedSize == 0 ? expectedEntries : Math.min(expectedEntries, maxBlockSizeInBytes / fixedSize)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractIntType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractIntType.java index e04bf07db5523..e61eae64c81bf 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractIntType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractIntType.java @@ -104,9 +104,16 @@ public int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int ri @Override public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new IntArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / Integer.BYTES)); + Math.min(expectedEntries, maxBlockSizeInBytes / Integer.BYTES)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractLongType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractLongType.java index 855b508095eee..dafe4133da27d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractLongType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractLongType.java @@ -102,9 +102,16 @@ public int compareTo(Block leftBlock, int leftPosition, Block rightBlock, int ri @Override public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new LongArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / Long.BYTES)); + Math.min(expectedEntries, maxBlockSizeInBytes / Long.BYTES)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractVariableWidthType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractVariableWidthType.java index 27d070598e628..7eaaeac22bc0c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractVariableWidthType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/AbstractVariableWidthType.java @@ -31,9 +31,16 @@ protected AbstractVariableWidthType(TypeSignature signature, Class javaType) @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new VariableWidthBlockBuilder( blockBuilderStatus, - expectedBytesPerEntry == 0 ? expectedEntries : Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / expectedBytesPerEntry), + expectedBytesPerEntry == 0 ? expectedEntries : Math.min(expectedEntries, maxBlockSizeInBytes / expectedBytesPerEntry), expectedBytesPerEntry); } diff --git a/presto-main/src/main/java/com/facebook/presto/type/ArrayType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/ArrayType.java similarity index 91% rename from presto-main/src/main/java/com/facebook/presto/type/ArrayType.java rename to presto-spi/src/main/java/com/facebook/presto/spi/type/ArrayType.java index 01c27ffea225b..c4ad958fe3530 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/ArrayType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/ArrayType.java @@ -11,20 +11,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; -import com.facebook.presto.operator.scalar.CombineHashFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.AbstractArrayBlock; import com.facebook.presto.spi.block.ArrayBlockBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.type.AbstractType; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import java.util.ArrayList; @@ -32,8 +26,9 @@ import java.util.List; import static com.facebook.presto.spi.type.StandardTypes.ARRAY; -import static com.facebook.presto.type.TypeUtils.checkElementNotNull; -import static com.facebook.presto.type.TypeUtils.hashPosition; +import static com.facebook.presto.spi.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.spi.type.TypeUtils.hashPosition; +import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; public class ArrayType @@ -92,7 +87,7 @@ public long hash(Block block, int position) Block array = getObject(block, position); long hash = 0; for (int i = 0; i < array.getPositionCount(); i++) { - hash = CombineHashFunction.getHash(hash, hashPosition(elementType, array, i)); + hash = 31 * hash + hashPosition(elementType, array, i); } return hash; } @@ -210,7 +205,7 @@ public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, in @Override public List getTypeParameters() { - return ImmutableList.of(getElementType()); + return singletonList(getElementType()); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/BooleanType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/BooleanType.java index f34f6d20dd22a..b4c3ec7da3a86 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/BooleanType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/BooleanType.java @@ -41,9 +41,16 @@ public int getFixedSize() @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new ByteArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / Byte.BYTES)); + Math.min(expectedEntries, maxBlockSizeInBytes / Byte.BYTES)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/CharType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/CharType.java index cbe2950d21608..97fcb4785c4be 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/CharType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/CharType.java @@ -145,6 +145,9 @@ public void writeSlice(BlockBuilder blockBuilder, Slice value) @Override public void writeSlice(BlockBuilder blockBuilder, Slice value, int offset, int length) { + if (length > 0 && value.getByte(offset + length - 1) == ' ') { + throw new IllegalArgumentException("Slice representing Char should not have trailing spaces"); + } blockBuilder.writeBytes(value, offset, length).closeEntry(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/Chars.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/Chars.java index 275e6e76b3af5..c7900ac0badd9 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/Chars.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/Chars.java @@ -94,10 +94,10 @@ public static Slice trimSpacesAndTruncateToLength(Slice slice, int maxLength) if (maxLength < 0) { throw new IllegalArgumentException("Max length must be greater or equal than zero"); } - return truncateToLength(trimSpaces(slice), maxLength); + return truncateToLength(trimTrailingSpaces(slice), maxLength); } - public static Slice trimSpaces(Slice slice) + public static Slice trimTrailingSpaces(Slice slice) { requireNonNull(slice, "slice is null"); return slice.slice(0, sliceLengthWithoutTrailingSpaces(slice)); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/DoubleType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/DoubleType.java index 4d51b7a3fe895..6ea28e8be1493 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/DoubleType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/DoubleType.java @@ -112,9 +112,16 @@ public void writeDouble(BlockBuilder blockBuilder, double value) @Override public final BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new LongArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / Double.BYTES)); + Math.min(expectedEntries, maxBlockSizeInBytes / Double.BYTES)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/LongDecimalType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/LongDecimalType.java index 345f341cb4015..47267475b5cae 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/LongDecimalType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/LongDecimalType.java @@ -44,10 +44,17 @@ public int getFixedSize() @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new FixedWidthBlockBuilder( getFixedSize(), blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / getFixedSize())); + Math.min(expectedEntries, maxBlockSizeInBytes / getFixedSize())); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/type/MapType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java similarity index 65% rename from presto-main/src/main/java/com/facebook/presto/type/MapType.java rename to presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java index ab3baad7272e8..7557e7288f669 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/MapType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/MapType.java @@ -11,56 +11,79 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.block.ArrayBlockBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; -import com.facebook.presto.spi.block.InterleavedBlockBuilder; -import com.facebook.presto.spi.type.AbstractType; -import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.facebook.presto.spi.type.TypeSignatureParameter; -import com.google.common.collect.ImmutableList; +import com.facebook.presto.spi.block.MapBlock; +import com.facebook.presto.spi.block.MapBlockBuilder; +import com.facebook.presto.spi.block.SingleMapBlock; +import java.lang.invoke.MethodHandle; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import static com.facebook.presto.type.TypeUtils.checkElementNotNull; -import static com.facebook.presto.type.TypeUtils.hashPosition; -import static com.google.common.base.Preconditions.checkArgument; +import static com.facebook.presto.spi.type.TypeUtils.checkElementNotNull; +import static com.facebook.presto.spi.type.TypeUtils.hashPosition; +import static java.lang.String.format; +import static java.util.Arrays.asList; +import static java.util.Objects.requireNonNull; public class MapType extends AbstractType { + private final boolean useNewMapBlock; + private final Type keyType; private final Type valueType; private static final String MAP_NULL_ELEMENT_MSG = "MAP comparison not supported for null value elements"; private static final int EXPECTED_BYTES_PER_ENTRY = 32; - public MapType(Type keyType, Type valueType) + private final MethodHandle keyNativeHashCode; + private final MethodHandle keyBlockHashCode; + private final MethodHandle keyBlockNativeEquals; + + public MapType(boolean useNewMapBlock, Type keyType, Type valueType, MethodHandle keyBlockNativeEquals, MethodHandle keyNativeHashCode, MethodHandle keyBlockHashCode) { super(new TypeSignature(StandardTypes.MAP, TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature())), Block.class); - checkArgument(keyType.isComparable(), "key type must be comparable"); + if (!keyType.isComparable()) { + throw new IllegalArgumentException(format("key type must be comparable, got %s", keyType)); + } + this.useNewMapBlock = useNewMapBlock; this.keyType = keyType; this.valueType = valueType; + if (useNewMapBlock) { + requireNonNull(keyBlockNativeEquals, "keyBlockNativeEquals is null"); + requireNonNull(keyNativeHashCode, "keyNativeHashCode is null"); + requireNonNull(keyBlockHashCode, "keyBlockHashCode is null"); + } + else { + if (keyBlockNativeEquals != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyBlockNativeEquals should be null."); + } + if (keyNativeHashCode != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyNativeHashCode should be null."); + } + if (keyBlockHashCode != null) { + throw new IllegalArgumentException("When useNewMapBlock is false, keyBlockHashCode should be null."); + } + } + this.keyBlockNativeEquals = keyBlockNativeEquals; + this.keyNativeHashCode = keyNativeHashCode; + this.keyBlockHashCode = keyBlockHashCode; } @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { - return new ArrayBlockBuilder( - new InterleavedBlockBuilder(getTypeParameters(), blockBuilderStatus, expectedEntries * 2, expectedBytesPerEntry), - blockBuilderStatus, - expectedEntries); + return new MapBlockBuilder(useNewMapBlock, keyType, valueType, keyBlockNativeEquals, keyNativeHashCode, keyBlockHashCode, blockBuilderStatus, expectedEntries); } @Override @@ -177,10 +200,13 @@ public Object getObjectValue(ConnectorSession session, Block block, int position return null; } - Block mapBlock = block.getObject(position, Block.class); + Block singleMapBlock = block.getObject(position, Block.class); + if (!(singleMapBlock instanceof SingleMapBlock)) { + throw new UnsupportedOperationException("Map is encoded with legacy block representation"); + } Map map = new HashMap<>(); - for (int i = 0; i < mapBlock.getPositionCount(); i += 2) { - map.put(keyType.getObjectValue(session, mapBlock, i), valueType.getObjectValue(session, mapBlock, i + 1)); + for (int i = 0; i < singleMapBlock.getPositionCount(); i += 2) { + map.put(keyType.getObjectValue(session, singleMapBlock, i), valueType.getObjectValue(session, singleMapBlock, i + 1)); } return Collections.unmodifiableMap(map); @@ -213,7 +239,7 @@ public void writeObject(BlockBuilder blockBuilder, Object value) @Override public List getTypeParameters() { - return ImmutableList.of(getKeyType(), getValueType()); + return asList(getKeyType(), getValueType()); } @Override @@ -221,4 +247,18 @@ public String getDisplayName() { return "map(" + keyType.getDisplayName() + ", " + valueType.getDisplayName() + ")"; } + + public MapBlock createBlockFromKeyValue(boolean[] mapIsNull, int[] offsets, Block keyBlock, Block valueBlock) + { + return MapBlock.fromKeyValueBlock( + useNewMapBlock, + mapIsNull, + offsets, + keyBlock, + valueBlock, + this, + keyBlockNativeEquals, + keyNativeHashCode, + keyBlockHashCode); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/ParametricType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/ParametricType.java index de49fc7a2e5c1..644df0e1331fe 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/ParametricType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/ParametricType.java @@ -19,5 +19,5 @@ public interface ParametricType { String getName(); - Type createType(List parameters); + Type createType(TypeManager typeManager, List parameters); } diff --git a/presto-main/src/main/java/com/facebook/presto/type/RowType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/RowType.java similarity index 81% rename from presto-main/src/main/java/com/facebook/presto/type/RowType.java rename to presto-spi/src/main/java/com/facebook/presto/spi/type/RowType.java index 84da31a00a964..9ca20f3904e84 100644 --- a/presto-main/src/main/java/com/facebook/presto/type/RowType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/RowType.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; @@ -21,12 +21,6 @@ import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; -import com.facebook.presto.spi.type.AbstractType; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; -import com.google.common.base.Joiner; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collections; @@ -34,8 +28,6 @@ import java.util.Optional; import static com.facebook.presto.spi.type.StandardTypes.ROW; -import static com.facebook.presto.type.TypeUtils.hashPosition; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /** @@ -49,20 +41,33 @@ public class RowType public RowType(List fieldTypes, Optional> fieldNames) { - super(new TypeSignature( - ROW, - Lists.transform(fieldTypes, Type::getTypeSignature), - fieldNames.orElse(ImmutableList.of()).stream() - .collect(toImmutableList())), - Block.class); + super(toTypeSignature(fieldTypes, fieldNames), Block.class); - ImmutableList.Builder builder = ImmutableList.builder(); + List fields = new ArrayList<>(); for (int i = 0; i < fieldTypes.size(); i++) { int index = i; - builder.add(new RowField(fieldTypes.get(i), fieldNames.map((names) -> names.get(index)))); + fields.add(new RowField(fieldTypes.get(i), fieldNames.map((names) -> names.get(index)))); } - fields = builder.build(); - this.fieldTypes = ImmutableList.copyOf(fieldTypes); + this.fields = fields; + this.fieldTypes = fieldTypes; + } + + private static TypeSignature toTypeSignature(List fieldTypes, Optional> fieldNames) + { + int size = fieldTypes.size(); + if (size == 0) { + throw new IllegalArgumentException("Row type must have at least 1 field"); + } + + List elementTypeSignatures = new ArrayList<>(); + List literalParameters = new ArrayList<>(); + for (int i = 0; i < size; i++) { + elementTypeSignatures.add(fieldTypes.get(i).getTypeSignature()); + if (fieldNames.isPresent()) { + literalParameters.add(fieldNames.get().get(i)); + } + } + return new TypeSignature(ROW, elementTypeSignatures, literalParameters); } @Override @@ -87,17 +92,21 @@ public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, in public String getDisplayName() { // Convert to standard sql name - List fieldDisplayNames = new ArrayList<>(); + StringBuilder result = new StringBuilder(); + result.append(ROW).append('('); for (RowField field : fields) { String typeDisplayName = field.getType().getDisplayName(); if (field.getName().isPresent()) { - fieldDisplayNames.add(field.getName().get() + " " + typeDisplayName); + result.append(field.getName().get()).append(' ').append(typeDisplayName); } else { - fieldDisplayNames.add(typeDisplayName); + result.append(typeDisplayName); } + result.append(", "); } - return ROW + "(" + Joiner.on(", ").join(fieldDisplayNames) + ")"; + result.setLength(result.length() - 2); + result.append(')'); + return result.toString(); } @Override @@ -233,7 +242,7 @@ public long hash(Block block, int position) long result = 1; for (int i = 0; i < arrayBlock.getPositionCount(); i++) { Type elementType = fields.get(i).getType(); - result = 31 * result + hashPosition(elementType, arrayBlock, i); + result = 31 * result + TypeUtils.hashPosition(elementType, arrayBlock, i); } return result; } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/ShortDecimalType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/ShortDecimalType.java index abc6c11cd5fff..a6f1efa0c2e59 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/ShortDecimalType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/ShortDecimalType.java @@ -42,9 +42,16 @@ public int getFixedSize() @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new LongArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / getFixedSize())); + Math.min(expectedEntries, maxBlockSizeInBytes / getFixedSize())); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/SmallintType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/SmallintType.java index dc31788f371ae..5fed7716783b3 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/SmallintType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/SmallintType.java @@ -44,9 +44,16 @@ public int getFixedSize() @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new ShortArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / Short.BYTES)); + Math.min(expectedEntries, maxBlockSizeInBytes / Short.BYTES)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/TinyintType.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/TinyintType.java index 94502d96b092b..c5ff2680af304 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/TinyintType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/TinyintType.java @@ -44,9 +44,16 @@ public int getFixedSize() @Override public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry) { + int maxBlockSizeInBytes; + if (blockBuilderStatus == null) { + maxBlockSizeInBytes = BlockBuilderStatus.DEFAULT_MAX_BLOCK_SIZE_IN_BYTES; + } + else { + maxBlockSizeInBytes = blockBuilderStatus.getMaxBlockSizeInBytes(); + } return new ByteArrayBlockBuilder( blockBuilderStatus, - Math.min(expectedEntries, blockBuilderStatus.getMaxBlockSizeInBytes() / Byte.BYTES)); + Math.min(expectedEntries, maxBlockSizeInBytes / Byte.BYTES)); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeManager.java index f89458d527ae1..6069d280cece6 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeManager.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.spi.type; +import com.facebook.presto.spi.function.OperatorType; + +import java.lang.invoke.MethodHandle; import java.util.Collection; import java.util.List; import java.util.Optional; @@ -50,4 +53,6 @@ default boolean canCoerce(Type actualType, Type expectedType) boolean isTypeOnlyCoercion(Type actualType, Type expectedType); Optional coerceTypeBase(Type sourceType, String resultTypeBase); + + MethodHandle resolveOperator(OperatorType operatorType, List argumentTypes); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeSignature.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeSignature.java index e38f8a73502ec..1570d045bbb0e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeSignature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeSignature.java @@ -20,8 +20,10 @@ import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.TreeMap; import java.util.stream.Collectors; import static java.lang.String.format; @@ -35,6 +37,13 @@ public class TypeSignature private final List parameters; private final boolean calculated; + private static final Map BASE_NAME_ALIAS_TO_CANONICAL = + new TreeMap(String.CASE_INSENSITIVE_ORDER); + + static { + BASE_NAME_ALIAS_TO_CANONICAL.put("int", StandardTypes.INTEGER); + } + public TypeSignature(String base, TypeSignatureParameter... parameters) { this(base, asList(parameters)); @@ -100,7 +109,7 @@ public static TypeSignature parseTypeSignature(String signature, Set lit return VarcharType.createUnboundedVarcharType().getTypeSignature(); } checkArgument(!literalCalculationParameters.contains(signature), "Bad type signature: '%s'", signature); - return new TypeSignature(signature, new ArrayList<>()); + return new TypeSignature(canonicalizeBaseName(signature), new ArrayList<>()); } if (signature.toLowerCase(Locale.ENGLISH).startsWith(StandardTypes.ROW + "(")) { return parseRowTypeSignature(signature, literalCalculationParameters); @@ -125,7 +134,7 @@ public static TypeSignature parseTypeSignature(String signature, Set lit if (bracketCount == 0) { verify(baseName == null, "Expected baseName to be null"); verify(parameterStart == -1, "Expected parameter start to be -1"); - baseName = signature.substring(0, i); + baseName = canonicalizeBaseName(signature.substring(0, i)); checkArgument(!literalCalculationParameters.contains(baseName), "Bad type signature: '%s'", signature); parameterStart = i + 1; } @@ -171,7 +180,7 @@ private static TypeSignature parseRowTypeSignature(String signature, Set if (bracketCount == 0) { verify(baseName == null, "Expected baseName to be null"); verify(parameterStart == -1, "Expected parameter start to be -1"); - baseName = signature.substring(0, i); + baseName = canonicalizeBaseName(signature.substring(0, i)); parameterStart = i + 1; inFieldName = true; } @@ -223,7 +232,7 @@ private static TypeSignature parseOldStyleRowTypeSignature(String signature, Set if (bracketCount == 0) { verify(baseName == null, "Expected baseName to be null"); verify(parameterStart == -1, "Expected parameter start to be -1"); - baseName = signature.substring(0, i); + baseName = canonicalizeBaseName(signature.substring(0, i)); parameterStart = i + 1; } bracketCount++; @@ -257,7 +266,7 @@ else if (c == '(') { if (baseName == null) { verify(parameters.isEmpty(), "Expected no parameters"); verify(parameterStart == -1, "Expected parameter start to be -1"); - baseName = signature.substring(0, i); + baseName = canonicalizeBaseName(signature.substring(0, i)); } parameterStart = i + 1; } @@ -385,6 +394,15 @@ private static boolean validateName(String name) return name.chars().noneMatch(c -> c == '<' || c == '>' || c == ','); } + private static String canonicalizeBaseName(String baseName) + { + String canonicalBaseName = BASE_NAME_ALIAS_TO_CANONICAL.get(baseName); + if (canonicalBaseName == null) { + return baseName; + } + return canonicalBaseName; + } + @Override public boolean equals(Object o) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java index efa34bcf0dd41..6648110ccf1b4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/type/TypeUtils.java @@ -13,13 +13,18 @@ */ package com.facebook.presto.spi.type; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; + public final class TypeUtils { + public static final int NULL_HASH_CODE = 0; + private TypeUtils() { } @@ -83,4 +88,19 @@ else if (value instanceof String) { type.writeObject(blockBuilder, value); } } + + static long hashPosition(Type type, Block block, int position) + { + if (block.isNull(position)) { + return NULL_HASH_CODE; + } + return type.hash(block, position); + } + + static void checkElementNotNull(boolean isNull, String errorMsg) + { + if (isNull) { + throw new PrestoException(NOT_SUPPORTED, errorMsg); + } + } } diff --git a/presto-spi/src/main/resources/com/facebook/presto/spi/type/zone-index.properties b/presto-spi/src/main/resources/com/facebook/presto/spi/type/zone-index.properties index 675c24744f378..2eb654f5527fe 100644 --- a/presto-spi/src/main/resources/com/facebook/presto/spi/type/zone-index.properties +++ b/presto-spi/src/main/resources/com/facebook/presto/spi/type/zone-index.properties @@ -2224,6 +2224,17 @@ 2215 Asia/Chita 2216 Asia/Srednekolymsk 2217 Pacific/Bougainville +2218 America/Fort_Nelson +2219 Asia/Barnaul +2220 Asia/Tomsk +2221 Europe/Astrakhan +2222 Europe/Kirov +2223 Europe/Ulyanovsk +2224 Asia/Yangon +2225 America/Punta_Arenas +2226 Asia/Atyrau +2227 Asia/Famagusta +2228 Europe/Saratov # Zones not supported in Java # ROC diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java b/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java index 605f7076f12d8..ea9da285e3990 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/TestPage.java @@ -87,7 +87,7 @@ public void testCompactDictionaryBlocks() int otherDictionaryUsedPositions = 30; int[] otherDictionaryIds = getDictionaryIds(positionCount, otherDictionaryUsedPositions); SliceArrayBlock dictionary3 = new SliceArrayBlock(70, createExpectedValues(70)); - DictionaryBlock randomSourceIdBlock = new DictionaryBlock(positionCount, dictionary3, otherDictionaryIds); + DictionaryBlock randomSourceIdBlock = new DictionaryBlock(dictionary3, otherDictionaryIds); Page page = new Page(commonSourceIdBlock1, randomSourceIdBlock, commonSourceIdBlock2); page.compact(); diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestArrayBlockBuilder.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestArrayBlockBuilder.java index 89879aed32e5c..68dc75b29084a 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestArrayBlockBuilder.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestArrayBlockBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spi.block; +import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; @@ -48,7 +49,22 @@ private void testIsFull(PageBuilderStatus pageBuilderStatus) assertEquals(pageBuilderStatus.isFull(), true); } - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Expected entry size to be .*") + //TODO we should systematically test Block::getRetainedSizeInBytes() + @Test + public void testRetainedSizeInBytes() + { + int expectedEntries = 1000; + BlockBuilder arrayBlockBuilder = new ArrayBlockBuilder(BIGINT, new BlockBuilderStatus(), expectedEntries); + long initialRetainedSize = arrayBlockBuilder.getRetainedSizeInBytes(); + for (int i = 0; i < expectedEntries; i++) { + BlockBuilder arrayElementBuilder = arrayBlockBuilder.beginBlockEntry(); + BIGINT.writeLong(arrayElementBuilder, i); + arrayBlockBuilder.closeEntry(); + } + assertTrue(arrayBlockBuilder.getRetainedSizeInBytes() >= (expectedEntries * Long.BYTES + ClassLayout.parseClass(LongArrayBlockBuilder.class).instanceSize() + initialRetainedSize)); + } + + @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Expected current entry to be closed but was opened") public void testConcurrentWriting() { BlockBuilder blockBuilder = new ArrayBlockBuilder(BIGINT, new BlockBuilderStatus(), EXPECTED_ENTRY_COUNT); diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestBlockRetainedSizeBreakdown.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestBlockRetainedSizeBreakdown.java new file mode 100644 index 0000000000000..7928ecc792455 --- /dev/null +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestBlockRetainedSizeBreakdown.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.block; + +import com.facebook.presto.spi.type.Type; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import it.unimi.dsi.fastutil.Hash.Strategy; +import it.unimi.dsi.fastutil.objects.Object2LongOpenCustomHashMap; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.TinyintType.TINYINT; +import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static io.airlift.slice.Slices.utf8Slice; +import static org.testng.Assert.assertEquals; + +public class TestBlockRetainedSizeBreakdown +{ + private static final int EXPECTED_ENTRIES = 100; + + @Test + public void testArrayBlock() + { + BlockBuilder arrayBlockBuilder = new ArrayBlockBuilder(BIGINT, new BlockBuilderStatus(), EXPECTED_ENTRIES); + for (int i = 0; i < EXPECTED_ENTRIES; i++) { + BlockBuilder arrayElementBuilder = arrayBlockBuilder.beginBlockEntry(); + writeNativeValue(BIGINT, arrayElementBuilder, castIntegerToObject(i, BIGINT)); + arrayBlockBuilder.closeEntry(); + } + checkRetainedSize(arrayBlockBuilder.build(), false); + } + + @Test + public void testByteArrayBlock() + { + BlockBuilder blockBuilder = new ByteArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + for (int i = 0; i < EXPECTED_ENTRIES; i++) { + blockBuilder.writeByte(i); + } + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testDictionaryBlock() + { + Block keyDictionaryBlock = createSliceArrayBlock(EXPECTED_ENTRIES); + int[] keyIds = new int[EXPECTED_ENTRIES]; + for (int i = 0; i < keyIds.length; i++) { + keyIds[i] = i; + } + checkRetainedSize(new DictionaryBlock(EXPECTED_ENTRIES, keyDictionaryBlock, keyIds), false); + } + + @Test + public void testFixedWidthBlock() + { + BlockBuilder blockBuilder = new FixedWidthBlockBuilder(8, new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, DOUBLE); + checkRetainedSize(blockBuilder.build(), true); + } + + @Test + public void testIntArrayBlock() + { + BlockBuilder blockBuilder = new IntArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, INTEGER); + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testInterleavedBlock() + { + BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(INTEGER, INTEGER), new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, INTEGER); + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testLongArrayBlock() + { + BlockBuilder blockBuilder = new LongArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + writeEntries(EXPECTED_ENTRIES, blockBuilder, BIGINT); + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testRunLengthEncodedBlock() + { + BlockBuilder blockBuilder = new LongArrayBlockBuilder(new BlockBuilderStatus(), 1); + writeEntries(1, blockBuilder, BIGINT); + checkRetainedSize(new RunLengthEncodedBlock(blockBuilder.build(), 1), false); + } + + @Test + public void testShortArrayBlock() + { + BlockBuilder blockBuilder = new ShortArrayBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES); + for (int i = 0; i < EXPECTED_ENTRIES; i++) { + blockBuilder.writeShort(i); + } + checkRetainedSize(blockBuilder.build(), false); + } + + @Test + public void testSliceArrayBlock() + { + checkRetainedSize(createSliceArrayBlock(EXPECTED_ENTRIES), true); + } + + @Test + public void testVariableWidthBlock() + { + BlockBuilder blockBuilder = new VariableWidthBlockBuilder(new BlockBuilderStatus(), EXPECTED_ENTRIES, 4); + writeEntries(EXPECTED_ENTRIES, blockBuilder, VARCHAR); + checkRetainedSize(blockBuilder.build(), false); + } + + private static final class ObjectStrategy + implements Strategy + { + @Override + public int hashCode(Object object) + { + return System.identityHashCode(object); + } + + @Override + public boolean equals(Object left, Object right) + { + return left == right; + } + } + + private static void checkRetainedSize(Block block, boolean getRegionCreateNewObjects) + { + AtomicLong objectSize = new AtomicLong(); + Object2LongOpenCustomHashMap trackedObjects = new Object2LongOpenCustomHashMap<>(new ObjectStrategy()); + + BiConsumer consumer = (object, size) -> { + objectSize.addAndGet(size); + trackedObjects.addTo(object, 1); + }; + + block.retainedBytesForEachPart(consumer); + assertEquals(objectSize.get(), block.getRetainedSizeInBytes()); + + Block copyBlock = block.getRegion(0, block.getPositionCount() / 2); + copyBlock.retainedBytesForEachPart(consumer); + assertEquals(objectSize.get(), block.getRetainedSizeInBytes() + copyBlock.getRetainedSizeInBytes()); + + assertEquals(trackedObjects.getLong(block), 1); + assertEquals(trackedObjects.getLong(copyBlock), 1); + trackedObjects.remove(block); + trackedObjects.remove(copyBlock); + for (long value : trackedObjects.values()) { + assertEquals(value, getRegionCreateNewObjects ? 1 : 2); + } + } + + private static void writeEntries(int expectedEntries, BlockBuilder blockBuilder, Type type) + { + for (int i = 0; i < expectedEntries; i++) { + writeNativeValue(type, blockBuilder, castIntegerToObject(i, type)); + } + } + + private static Object castIntegerToObject(int value, Type type) + { + if (type == INTEGER || type == TINYINT || type == BIGINT) { + return (long) value; + } + if (type == VARCHAR) { + return String.valueOf(value); + } + if (type == DOUBLE) { + return (double) value; + } + throw new UnsupportedOperationException(); + } + + private static Block createSliceArrayBlock(int entries) + { + Slice[] sliceArray = new Slice[entries]; + for (int i = 0; i < entries; i++) { + sliceArray[i] = utf8Slice(i + ""); + } + return new SliceArrayBlock(sliceArray.length, sliceArray); + } +} diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java index 843350ee0adb0..6f1f6ec3e2727 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestDictionaryBlockEncoding.java @@ -92,7 +92,7 @@ public void testRoundTrip() } BlockEncoding blockEncoding = new DictionaryBlockEncoding(new VariableWidthBlockEncoding()); - DictionaryBlock dictionaryBlock = new DictionaryBlock(positionCount, dictionary, ids); + DictionaryBlock dictionaryBlock = new DictionaryBlock(dictionary, ids); DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1024); blockEncoding.writeBlock(sliceOutput, dictionaryBlock); diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestMethodHandleUtil.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestMethodHandleUtil.java new file mode 100644 index 0000000000000..0fc5c31d1156f --- /dev/null +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestMethodHandleUtil.java @@ -0,0 +1,222 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.block; + +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; +import java.util.Objects; + +import static com.facebook.presto.spi.block.MethodHandleUtil.compose; +import static com.facebook.presto.spi.block.MethodHandleUtil.methodHandle; +import static com.google.common.base.MoreObjects.toStringHelper; +import static org.testng.Assert.assertEquals; + +public class TestMethodHandleUtil +{ + // Each custom type in this test is effectively a number. + // All method handles in this test returns the product of all input parameters. + // Each method handles has distinct input types and return type. + + // The composed function is invoked once to verify that: + // * The composed function type is expected + // * Each argument is multiplied into the product exactly once. (by using prime numbers as input) + + @Test + public void testCompose2() + throws Throwable + { + MethodHandle fUS2R = methodHandle(TestMethodHandleUtil.class, "fUS2R", U.class, S1.class, S2.class); + MethodHandle fT2U = methodHandle(TestMethodHandleUtil.class, "fT2U", T1.class, T2.class); + MethodHandle composed = compose(fUS2R, fT2U); + assertEquals((R) composed.invokeExact(new T1(2), new T2(3), new S1(5), new S2(7)), new R(210)); + } + + @Test + public void testCompose2withoutS() + throws Throwable + { + MethodHandle fU2R = methodHandle(TestMethodHandleUtil.class, "fU2R", U.class); + MethodHandle fT2U = methodHandle(TestMethodHandleUtil.class, "fT2U", T1.class, T2.class); + MethodHandle composed = compose(fU2R, fT2U); + assertEquals((R) composed.invokeExact(new T1(2), new T2(3)), new R(6)); + } + + @Test + public void testCompose3() + throws Throwable + { + MethodHandle fUV2R = methodHandle(TestMethodHandleUtil.class, "fUV2R", U.class, V.class); + MethodHandle fS2U = methodHandle(TestMethodHandleUtil.class, "fS2U", S1.class, S2.class); + MethodHandle fT2V = methodHandle(TestMethodHandleUtil.class, "fT2V", T1.class, T2.class); + MethodHandle composed = compose(fUV2R, fS2U, fT2V); + assertEquals((R) composed.invokeExact(new S1(2), new S2(3), new T1(5), new T2(7)), new R(210)); + } + + public static R fU2R(U u) + { + return new R(u.getValue()); + } + + public static R fUS2R(U u, S1 s1, S2 s2) + { + return new R(u.getValue() * s1.getValue() * s2.getValue()); + } + + public static R fUV2R(U u, V v) + { + return new R(u.getValue() * v.getValue()); + } + + public static U fT2U(T1 t1, T2 t2) + { + return new U(t1.getValue() * t2.getValue()); + } + + public static U fS2U(S1 s1, S2 s2) + { + return new U(s1.getValue() * s2.getValue()); + } + + public static V fT2V(T1 t1, T2 t2) + { + return new V(t1.getValue() * t2.getValue()); + } + + public static String squareBracket(String s) + { + return "[" + s + "]"; + } + + public static String squareBracket(String s, double d) + { + return "[" + s + "," + ((long) d) + "]"; + } + + public static String curlyBracket(String s, char c) + { + return "{" + s + "=" + c + "}"; + } + + public static double sum(long x, int c) + { + return (double) x + c; + } + + private static class Base + { + private final int value; + + public Base(int value) + { + this.value = value; + } + + public int getValue() + { + return value; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Base base = (Base) o; + return value == base.value; + } + + @Override + public int hashCode() + { + return Objects.hash(value); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("value", value) + .toString(); + } + } + + private static class U + extends Base + { + public U(int value) + { + super(value); + } + } + + private static class V + extends Base + { + public V(int value) + { + super(value); + } + } + + private static class R + extends Base + { + public R(int value) + { + super(value); + } + } + + private static class S1 + extends Base + { + public S1(int value) + { + super(value); + } + } + + private static class S2 + extends Base + { + public S2(int value) + { + super(value); + } + } + + private static class T1 + extends Base + { + public T1(int value) + { + super(value); + } + } + + private static class T2 + extends Base + { + public T2(int value) + { + super(value); + } + } +} diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestingBlockEncodingSerde.java b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestingBlockEncodingSerde.java index 3ea61ea98c01f..dd58305d56b6a 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/block/TestingBlockEncodingSerde.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/block/TestingBlockEncodingSerde.java @@ -54,6 +54,8 @@ public TestingBlockEncodingSerde(TypeManager typeManager, Set factory : requireNonNull(blockEncodingFactories, "blockEncodingFactories is null")) { addBlockEncodingFactory(factory); diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/resourceGroups/TestResourceGroupId.java b/presto-spi/src/test/java/com/facebook/presto/spi/resourceGroups/TestResourceGroupId.java index 9e3e0ead45126..3a1d131c96c9a 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/resourceGroups/TestResourceGroupId.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/resourceGroups/TestResourceGroupId.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.spi.resourceGroups; +import io.airlift.json.JsonCodec; import org.testng.annotations.Test; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.AssertJUnit.assertTrue; @@ -23,9 +25,20 @@ public class TestResourceGroupId @Test public void testBasic() { + new ResourceGroupId("test_test"); new ResourceGroupId("test.test"); new ResourceGroupId(new ResourceGroupId("test"), "test"); } + @Test + public void testCodec() + { + JsonCodec codec = JsonCodec.jsonCodec(ResourceGroupId.class); + ResourceGroupId resourceGroupId = new ResourceGroupId(new ResourceGroupId("test.test"), "foo"); + assertEquals(codec.fromJson(codec.toJson(resourceGroupId)), resourceGroupId); + + assertEquals(codec.toJson(resourceGroupId), "[ \"test.test\", \"foo\" ]"); + assertEquals(codec.fromJson("[\"test.test\", \"foo\"]"), resourceGroupId); + } @Test public void testIsAncestor() diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayType.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestArrayType.java similarity index 95% rename from presto-main/src/test/java/com/facebook/presto/type/TestArrayType.java rename to presto-spi/src/test/java/com/facebook/presto/spi/type/TestArrayType.java index a5281e225bfef..ba1b0dcc4c75a 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayType.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestArrayType.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; import org.testng.annotations.Test; diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestMapType.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestMapType.java new file mode 100644 index 0000000000000..59260f5fd01e3 --- /dev/null +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestMapType.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.type; + +import com.facebook.presto.spi.block.MethodHandleUtil; +import org.testng.annotations.Test; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.type.VarcharType.createVarcharType; +import static org.testng.Assert.assertEquals; + +public class TestMapType +{ + @Test + public void testMapDisplayName() + { + MapType mapType = new MapType( + true, + BIGINT, + createVarcharType(42), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation")); + assertEquals(mapType.getDisplayName(), "map(bigint, varchar(42))"); + + mapType = new MapType( + true, + BIGINT, + VARCHAR, + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation"), + MethodHandleUtil.methodHandle(TestMapType.class, "throwUnsupportedOperation")); + assertEquals(mapType.getDisplayName(), "map(bigint, varchar)"); + } + + public static void throwUnsupportedOperation() + { + throw new UnsupportedOperationException(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestRowType.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestRowType.java similarity index 56% rename from presto-main/src/test/java/com/facebook/presto/type/TestRowType.java rename to presto-spi/src/test/java/com/facebook/presto/spi/type/TestRowType.java index 337033ce9d540..d86bf78267af5 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestRowType.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestRowType.java @@ -11,14 +11,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.type; +package com.facebook.presto.spi.type; -import com.facebook.presto.spi.type.Type; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.block.MethodHandleUtil.methodHandle; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; @@ -30,7 +30,17 @@ public class TestRowType @Test public void testRowDisplayName() { - List types = asList(BOOLEAN, DOUBLE, new ArrayType(VARCHAR), new MapType(BOOLEAN, DOUBLE)); + List types = asList( + BOOLEAN, + DOUBLE, + new ArrayType(VARCHAR), + new MapType( + true, + BOOLEAN, + DOUBLE, + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"))); Optional> names = Optional.of(asList("bool_col", "double_col", "array_col", "map_col")); RowType row = new RowType(types, names); assertEquals( @@ -41,10 +51,25 @@ public void testRowDisplayName() @Test public void testRowDisplayNoColumnNames() { - List types = asList(BOOLEAN, DOUBLE, new ArrayType(VARCHAR), new MapType(BOOLEAN, DOUBLE)); + List types = asList( + BOOLEAN, + DOUBLE, + new ArrayType(VARCHAR), + new MapType( + true, + BOOLEAN, + DOUBLE, + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"), + methodHandle(TestMapType.class, "throwUnsupportedOperation"))); RowType row = new RowType(types, Optional.empty()); assertEquals( row.getDisplayName(), "row(boolean, double, array(varchar), map(boolean, double))"); } + + public static void throwUnsupportedOperation() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTimeZoneKey.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTimeZoneKey.java index aea8328ec592d..359b589b762d7 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTimeZoneKey.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTimeZoneKey.java @@ -191,7 +191,7 @@ public int compare(TimeZoneKey left, TimeZoneKey right) hasher.putShort(timeZoneKey.getKey()); hasher.putString(timeZoneKey.getId(), StandardCharsets.UTF_8); } - // Zone file should not (normally) be changed, so let's make is more difficult - assertEquals(hasher.hash().asLong(), 5498515770239515435L, "zone-index.properties file contents changed!"); + // Zone file should not (normally) be changed, so let's make this more difficult + assertEquals(hasher.hash().asLong(), -5839014144088293930L, "zone-index.properties file contents changed!"); } } diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTypeSignature.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTypeSignature.java index a080165c14589..7b33ca0b3c3d8 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTypeSignature.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestTypeSignature.java @@ -35,7 +35,8 @@ public class TestTypeSignature { @Test - public void parseSignatureWithLiterals() throws Exception + public void parseSignatureWithLiterals() + throws Exception { TypeSignature result = parseTypeSignature("decimal(X,42)", ImmutableSet.of("X")); assertEquals(result.getParameters().size(), 2); @@ -50,6 +51,7 @@ public void parseRowSignature() assertRowSignature( "row(a bigint,b varchar)", rowSignature(namedParameter("a", signature("bigint")), namedParameter("b", varchar()))); + assertEquals(parseTypeSignature("row(col iNt)"), parseTypeSignature("row(col integer)")); assertRowSignature( "ROW(a bigint,b varchar)", "ROW", @@ -77,6 +79,7 @@ public void parseRowSignature() "row(a decimal(p1,s1),b decimal(p2,s2))", ImmutableSet.of("p1", "s1", "p2", "s2"), rowSignature(namedParameter("a", decimal("p1", "s1")), namedParameter("b", decimal("p2", "s2")))); + assertEquals(parseTypeSignature("row(a Int(p1))"), parseTypeSignature("row(a integer(p1))")); // TODO: remove the following tests when the old style row type has been completely dropped assertOldRowSignature( @@ -96,6 +99,7 @@ public void parseRowSignature() assertOldRowSignature( "array(row('col0','col1'))", array(rowSignature(namedParameter("col0", signature("bigint")), namedParameter("col1", signature("double"))))); + assertEquals(parseTypeSignature("array(row('col'))"), parseTypeSignature("array(row('col'))")); assertOldRowSignature( "row('col0','col1'))>('col0')", rowSignature(namedParameter("col0", array( @@ -149,9 +153,13 @@ public void parseSignature() assertSignature("bigint", "bigint", ImmutableList.of()); assertSignature("boolean", "boolean", ImmutableList.of()); assertSignature("varchar", "varchar", ImmutableList.of(Integer.toString(VarcharType.UNBOUNDED_LENGTH))); + assertEquals(parseTypeSignature("int"), parseTypeSignature("integer")); assertSignature("array(bigint)", "array", ImmutableList.of("bigint")); + assertEquals(parseTypeSignature("array(int)"), parseTypeSignature("array(integer)")); + assertSignature("array(array(bigint))", "array", ImmutableList.of("array(bigint)")); + assertEquals(parseTypeSignature("array(array(int))"), parseTypeSignature("array(array(integer))")); assertSignature( "array(timestamp with time zone)", "array", diff --git a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestingTypeManager.java b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestingTypeManager.java index d7e7fe577cf0b..5804b82130634 100644 --- a/presto-spi/src/test/java/com/facebook/presto/spi/type/TestingTypeManager.java +++ b/presto-spi/src/test/java/com/facebook/presto/spi/type/TestingTypeManager.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.spi.type; +import com.facebook.presto.spi.function.OperatorType; import com.google.common.collect.ImmutableList; +import java.lang.invoke.MethodHandle; import java.util.Collection; import java.util.List; import java.util.Optional; @@ -78,4 +80,10 @@ public Optional coerceTypeBase(Type sourceType, String resultTypeBase) { throw new UnsupportedOperationException(); } + + @Override + public MethodHandle resolveOperator(OperatorType operatorType, List argumentTypes) + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-sqlserver/pom.xml b/presto-sqlserver/pom.xml index 373ebe9d03b35..cf2247816fef5 100644 --- a/presto-sqlserver/pom.xml +++ b/presto-sqlserver/pom.xml @@ -3,7 +3,7 @@ presto-root com.facebook.presto - 0.175-SNAPSHOT + 0.180-SNAPSHOT 4.0.0 diff --git a/presto-teradata-functions/pom.xml b/presto-teradata-functions/pom.xml index 4ed9c788840a8..23b7a0c08fec5 100644 --- a/presto-teradata-functions/pom.xml +++ b/presto-teradata-functions/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-teradata-functions diff --git a/presto-testing-server-launcher/pom.xml b/presto-testing-server-launcher/pom.xml index 0bbc77d46fff2..3ae4056b6edd9 100644 --- a/presto-testing-server-launcher/pom.xml +++ b/presto-testing-server-launcher/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-testing-server-launcher diff --git a/presto-tests/pom.xml b/presto-tests/pom.xml index 8a10435e5bd82..3a2a35af8db8d 100644 --- a/presto-tests/pom.xml +++ b/presto-tests/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-tests @@ -67,11 +67,6 @@ discovery-server - - io.airlift - http-client - - io.airlift http-server @@ -127,6 +122,11 @@ testing + + com.squareup.okhttp3 + okhttp + + com.fasterxml.jackson.core jackson-annotations diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java index 193d13c0de182..144aab4a2c7a9 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java @@ -291,8 +291,11 @@ public void testExplainAnalyzeDDL() private void assertExplainAnalyze(@Language("SQL") String query) { String value = getOnlyElement(computeActual(query).getOnlyColumnAsSet()); + + assertTrue(value.matches("(?s:.*)CPU:.*, Input:.*, Output(?s:.*)"), format("Expected output to contain \"CPU:.*, Input:.*, Output\", but it is %s", value)); + // TODO: check that rendered plan is as expected, once stats are collected in a consistent way - assertTrue(value.contains("Cost: "), format("Expected output to contain \"Cost: \", but it is %s", value)); + // assertTrue(value.contains("Cost: "), format("Expected output to contain \"Cost: \", but it is %s", value)); } protected void assertCreateTableAsSelect(String table, @Language("SQL") String query, @Language("SQL") String rowCountQuery) @@ -482,6 +485,7 @@ public void testDelete() assertUpdate("CREATE TABLE test_delete AS SELECT * FROM orders", "SELECT count(*) FROM orders"); assertUpdate("DELETE FROM test_delete WHERE rand() < 0", 0); + assertUpdate("DELETE FROM test_delete WHERE orderkey < 0", 0); assertUpdate("DROP TABLE test_delete"); // delete with a predicate that optimizes to false diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestIntegrationSmokeTest.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestIntegrationSmokeTest.java index 00ffe6432e9ff..211b23c6269b9 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestIntegrationSmokeTest.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestIntegrationSmokeTest.java @@ -135,6 +135,20 @@ public void testDescribeTable() assertTrue(expectedColumnsPossibilities.contains(actualColumns), String.format("%s not in %s", actualColumns, expectedColumnsPossibilities)); } + @Test + public void testDuplicatedRowCreateTable() + { + assertQueryFails("CREATE TABLE test (a integer, a integer)", + "line 1:31: Column name 'a' specified more than once"); + assertQueryFails("CREATE TABLE test (a integer, orderkey integer, LIKE orders INCLUDING PROPERTIES)", + "line 1:49: Column name 'orderkey' specified more than once"); + + assertQueryFails("CREATE TABLE test (a integer, A integer)", + "line 1:31: Column name 'A' specified more than once"); + assertQueryFails("CREATE TABLE test (a integer, OrderKey integer, LIKE orders INCLUDING PROPERTIES)", + "line 1:49: Column name 'orderkey' specified more than once"); + } + private MaterializedResult getExpectedTableDescription(boolean dateSupported, boolean parametrizedVarchar) { String orderDateType; diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 6d40773310fd4..e8557d08b56ed 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -24,7 +24,6 @@ import com.facebook.presto.testing.Arguments; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.SqlIntervalDayTime; import com.facebook.presto.type.SqlIntervalYearMonth; import com.facebook.presto.util.DateTimeZoneIndex; @@ -52,7 +51,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.function.Predicate; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -88,6 +86,7 @@ import static com.facebook.presto.tests.QueryAssertions.assertEqualsIgnoreOrder; import static com.facebook.presto.tests.QueryTemplate.parameter; import static com.facebook.presto.tests.QueryTemplate.queryTemplate; +import static com.facebook.presto.tests.StructuralTestUtil.mapType; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Iterables.transform; @@ -298,11 +297,11 @@ public void testVarbinary() public void testRowFieldAccessor() { //Dereference only - assertQuery("SELECT a.col0 FROM (VALUES ROW (CAST(ROW(1, 2) AS ROW(col0 integer, col1 integer)))) AS t (a)", "SELECT 1"); + assertQuery("SELECT a.col0 FROM (VALUES ROW (CAST(ROW(1, 2) AS ROW(col0 integer, col1 int)))) AS t (a)", "SELECT 1"); assertQuery("SELECT a.col0 FROM (VALUES ROW (CAST(ROW(1.0, 2.0) AS ROW(col0 integer, col1 integer)))) AS t (a)", "SELECT 1.0"); assertQuery("SELECT a.col0 FROM (VALUES ROW (CAST(ROW(TRUE, FALSE) AS ROW(col0 boolean, col1 boolean)))) AS t (a)", "SELECT TRUE"); assertQuery("SELECT a.col1 FROM (VALUES ROW (CAST(ROW(1.0, 'kittens') AS ROW(col0 varchar, col1 varchar)))) AS t (a)", "SELECT 'kittens'"); - assertQuery("SELECT a.col2.col1 FROM (VALUES ROW(CAST(ROW(1.0, ARRAY[2], row(3, 4.0)) AS ROW(col0 double, col1 array(integer), col2 row(col0 integer, col1 double))))) t(a)", "SELECT 4.0"); + assertQuery("SELECT a.col2.col1 FROM (VALUES ROW(CAST(ROW(1.0, ARRAY[2], row(3, 4.0)) AS ROW(col0 double, col1 array(int), col2 row(col0 integer, col1 double))))) t(a)", "SELECT 4.0"); // mixture of row field reference and table field reference assertQuery("SELECT cast(row(1, t.x) as row(col0 bigint, col1 bigint)).col1 FROM (VALUES 1, 2, 3) t(x)", "SELECT * FROM (VALUES 1, 2, 3)"); @@ -588,6 +587,16 @@ public void testUnnest() "SELECT * FROM (SELECT custkey FROM orders ORDER BY orderkey LIMIT 1) CROSS JOIN (VALUES (10, 1), (20, 2), (30, 3))"); assertQuery("SELECT * FROM orders, UNNEST(ARRAY[1])", "SELECT orders.*, 1 FROM orders"); + + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN UNNEST(x) ON true", + "line .*: UNNEST on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN UNNEST(x) ON true", + "line .*: UNNEST on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN UNNEST(x) ON true", + "line .*: UNNEST on other than the right side of CROSS JOIN is not supported"); } @Test @@ -1158,6 +1167,9 @@ public void testLegacyOrderByWithOutputColumnReference() assertQueryOrdered(session, "SELECT -a AS a FROM (VALUES 1, 2) t(a) ORDER BY first_value(a+t.a*2) OVER (ORDER BY a ROWS 0 PRECEDING)", "VALUES -1, -2"); assertQueryFails(session, "SELECT a as a, a* -1 AS a FROM (VALUES -1, 0, 2) t(a) ORDER BY a", ".*'a' in ORDER BY is ambiguous"); + + // grouping + assertQueryOrdered(session, "SELECT grouping(a) as c FROM (VALUES (-1, -1), (1, 1)) AS t (a, b) GROUP BY GROUPING SETS (a, b) ORDER BY c ASC", "VALUES 0, 0, 1, 1"); } @Test @@ -1841,6 +1853,148 @@ public void testRollupOverUnion() "SELECT * FROM (VALUES (0, 5), (1, 5), (2, 6), (3, 5), (4, 5), (100, 1), (NULL, 27))"); } + @Test + public void testGrouping() + throws Exception + { + assertQuery( + "SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " + + "FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " + + "GROUP BY GROUPING SETS ( (a), (b)) " + + "ORDER BY grouping(b) ASC", + "VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)"); + + assertQuery( + "SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)", + "VALUES ('h', 11, 0), ('k', 7, 0)"); + + assertQuery( + "SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ", + "VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)"); + + assertQuery("SELECT a, grouping(a) * 1.0 FROM (VALUES (1) ) AS t (a) GROUP BY a", + "VALUES (1, 0.0)"); + + assertQuery("SELECT a, grouping(a), grouping(a) FROM (VALUES (1) ) AS t (a) GROUP BY a", + "VALUES (1, 0, 0)"); + + assertQuery("SELECT grouping(a) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) GROUP BY GROUPING SETS (a,c), c*2", + "VALUES (0), (1), (0), (1)"); + } + + @Test + public void testGroupingWithFortyArguments() + { + // This test ensures we correctly pick the bigint implementation version of the grouping + // function which supports up to 62 columns. Semantically it is exactly the same as + // TestGroupingOperationFunction#testMoreThanThirtyTwoArguments. That test is a little easier to + // understand and verify. + String fortyLetterSequence = "aa, ab, ac, ad, ae, af, ag, ah, ai, aj, ak, al, am, an, ao, ap, aq, ar, asa, at, au, av, aw, ax, ay, az, " + + "ba, bb, bc, bd, be, bf, bg, bh, bi, bj, bk, bl, bm, bn"; + String fortyIntegers = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, " + + "31, 32, 33, 34, 35, 36, 37, 38, 39, 40"; + // 20, 2, 13, 33, 40, 9 , 14 (corresponding indices from left to right in the above fortyLetterSequence) + String groupingSet1 = "at, ab, am, bg, bn, ai, an"; + // 28, 4, 5, 29, 31, 10 (corresponding indices from left to right in the above fortyLetterSequence) + String groupingSet2 = "bb, ad, ae, bc, be, aj"; + String query = String.format( + "SELECT grouping(%s) FROM (VALUES (%s)) AS t(%s) GROUP BY GROUPING SETS ((%s), (%s), (%s))", + fortyLetterSequence, + fortyIntegers, + fortyLetterSequence, + fortyLetterSequence, + groupingSet1, + groupingSet2); + + assertQuery(query, "VALUES (0), (822283861886), (995358664191)"); + } + + @Test + public void testGroupingInWindowFunction() + throws Exception + { + assertQuery( + "SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " + + " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + + " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " + + "FROM orders " + + "GROUP BY ROLLUP (orderkey, custkey) " + + "ORDER BY orderkey, custkey " + + "LIMIT 10", + "VALUES (1, 370, 172799.49, 0, 1), " + + " (1, NULL, 172799.49, 1, 1), " + + " (2, 781, 38426.09, 0, 1), " + + " (2, NULL, 38426.09, 1, 2), " + + " (3, 1234, 205654.30, 0, 1), " + + " (3, NULL, 205654.30, 1, 3), " + + " (4, 1369, 56000.91, 0, 1), " + + " (4, NULL, 56000.91, 1, 4), " + + " (5, 445, 105367.67, 0, 1), " + + " (5, NULL, 105367.67, 1, 5)"); + } + + @Test + public void testGroupingInTableSubquery() + throws Exception + { + // In addition to testing grouping() in subqueries, the following tests also + // ensure correct behavior in the case of alternating GROUPING SETS and GROUP BY + // clauses in the same plan. This is significant because grouping() with GROUP BY + // works only with a special re-write that should not happen in the presence of + // GROUPING SETS. + + // Inner query has a single GROUP BY and outer query has GROUPING SETS + assertQuery( + "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + + " FROM orders " + + " GROUP BY orderkey, custkey " + + " ORDER BY agg_price ASC " + + " LIMIT 5) as t " + + "GROUP BY GROUPING SETS ((orderkey, custkey), g) " + + "ORDER BY outer_sum", + "VALUES (35271, 334, 874.89, 0, NULL), " + + " (28647, 1351, 924.33, 0, NULL), " + + " (58145, 862, 929.03, 0, NULL), " + + " (8354, 634, 974.04, 0, NULL), " + + " (37415, 301, 986.63, 0, NULL), " + + " (NULL, NULL, 4688.92, 3, 0)"); + + // Inner query has GROUPING SETS and outer query has GROUP BY + assertQuery( + "SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + + " FROM orders " + + " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + + " ORDER BY agg_price ASC " + + " LIMIT 5) as t " + + "GROUP BY orderkey, custkey, g", + "VALUES (28647, NULL, 2, 924.33, 0), " + + " (8354, NULL, 2, 974.04, 0), " + + " (37415, NULL, 2, 986.63, 0), " + + " (58145, NULL, 2, 929.03, 0), " + + " (35271, NULL, 2, 874.89, 0)"); + + // Inner query has GROUPING SETS but no grouping and outer query has a simple GROUP BY + assertQuery( + "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price " + + " FROM orders " + + " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + + " ORDER BY agg_price ASC NULLS FIRST) as t " + + "GROUP BY orderkey, custkey " + + "ORDER BY outer_sum ASC NULLS FIRST " + + "LIMIT 5", + "VALUES (35271, NULL, 874.89, 0), " + + " (28647, NULL, 924.33, 0), " + + " (58145, NULL, 929.03, 0), " + + " (8354, NULL, 974.04, 0), " + + " (37415, NULL, 986.63, 0)"); + } + @Test public void testIntersect() { @@ -2121,6 +2275,17 @@ public void testJoinWithLessThanInJoinClause() { assertQuery("SELECT n.nationkey, r.regionkey FROM region r JOIN nation n ON n.regionkey = r.regionkey AND n.name < r.name"); assertQuery("SELECT l.suppkey, n.nationkey, l.partkey, n.regionkey FROM nation n JOIN lineitem l ON l.suppkey = n.nationkey AND l.partkey < n.regionkey"); + // test with single null value in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, CAST(-1 AS BIGINT)), (0, NULL), (0, CAST(0 AS BIGINT))) t(a, b) WHERE n.regionkey - 100 < t.b AND n.nationkey = t.a", + "VALUES -1, 0"); + // test with single (first) null value in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, NULL), (0, CAST(-1 AS BIGINT)), (0, CAST(0 AS BIGINT))) t(a, b) WHERE n.regionkey - 100 < t.b AND n.nationkey = t.a", + "VALUES -1, 0"); + // test with multiple null values in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, NULL), (0, NULL), (0, CAST(-1 AS BIGINT)), (0, NULL)) t(a, b) WHERE n.regionkey - 100 < t.b AND n.nationkey = t.a", + "VALUES -1"); + // test with only null value in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, NULL)) t(a, b) WHERE n.regionkey - 100 < t.b AND n.nationkey = t.a", "SELECT 1 WHERE FALSE"); } @Test @@ -2129,6 +2294,17 @@ public void testJoinWithGreaterThanInJoinClause() { assertQuery("SELECT n.nationkey, r.regionkey FROM region r JOIN nation n ON n.regionkey = r.regionkey AND n.name > r.name AND r.regionkey = 0"); assertQuery("SELECT l.suppkey, n.nationkey, l.partkey, n.regionkey FROM nation n JOIN lineitem l ON l.suppkey = n.nationkey AND l.partkey > n.regionkey"); + // test with single null value in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, CAST(-1 AS BIGINT)), (0, NULL), (0, CAST(0 AS BIGINT))) t(a, b) WHERE n.regionkey + 100 > t.b AND n.nationkey = t.a", + "VALUES -1, 0"); + // test with single (first) null value in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, NULL), (0, CAST(-1 AS BIGINT)), (0, CAST(0 AS BIGINT))) t(a, b) WHERE n.regionkey + 100 > t.b AND n.nationkey = t.a", + "VALUES -1, 0"); + // test with multiple null values in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, NULL), (0, NULL), (0, CAST(-1 AS BIGINT)), (0, NULL)) t(a, b) WHERE n.regionkey + 100 > t.b AND n.nationkey = t.a", + "VALUES -1"); + // test with only null value in build side + assertQuery("SELECT b FROM nation n, (VALUES (0, NULL)) t(a, b) WHERE n.regionkey + 100 > t.b AND n.nationkey = t.a", "SELECT 1 WHERE FALSE"); } @Test @@ -2140,6 +2316,27 @@ public void testJoinWithLessThanOnDatesInJoinClause() "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON l.orderkey = o.orderkey AND l.shipdate < DATEADD('DAY', 10, o.orderdate)"); } + @Test + public void testJoinWithNonDeterministicLessThan() + { + MaterializedRow actualRow = getOnlyElement(computeActual( + "SELECT count(*) FROM " + + "customer c1 JOIN customer c2 ON c1.nationkey=c2.nationkey " + + "WHERE c1.custkey - RANDOM(c1.custkey) < c2.custkey").getMaterializedRows()); + assertEquals(actualRow.getFieldCount(), 1); + long actualCount = (Long) actualRow.getField(0); // this should be around ~69000 + + MaterializedRow expectedAtLeastRow = getOnlyElement(computeActual( + "SELECT count(*) FROM " + + "customer c1 JOIN customer c2 ON c1.nationkey=c2.nationkey " + + "WHERE c1.custkey < c2.custkey").getMaterializedRows()); + assertEquals(expectedAtLeastRow.getFieldCount(), 1); + long expectedAtLeastCount = (Long) expectedAtLeastRow.getField(0); // this is exactly 45022 + + // Technically non-deterministic unit test but has hopefully a next to impossible chance of a false positive + assertTrue(actualCount > expectedAtLeastCount); + } + @Test public void testSimpleJoin() { @@ -3474,11 +3671,11 @@ public void testJoinWithStatefulFilterFunction() } @Test - public void testAggregationOverRigthJoinOverSingleStreamProbe() + public void testAggregationOverRightJoinOverSingleStreamProbe() { // this should return one row since value is always 'value' // this test verifies that the two streams produced by the right join - // are handled gathered for the aggergation operator + // are handled gathered for the aggregation operator assertQueryOrdered("" + "SELECT\n" + " value\n" + @@ -3562,6 +3759,29 @@ public void testOrderByOrdinalWithWildcard() assertQueryOrdered("SELECT * FROM orders ORDER BY 1"); } + @Test + public void testOrderByWithSimilarExpressions() + { + assertQuery( + "WITH t AS (SELECT 1 x, 2 y) SELECT x, y FROM t ORDER BY x, y", + "SELECT 1, 2"); + assertQuery( + "WITH t AS (SELECT 1 x, 2 y) SELECT x, y FROM t ORDER BY x, y LIMIT 1", + "SELECT 1, 2"); + assertQuery( + "WITH t AS (SELECT 1 x, 1 y) SELECT x, y FROM t ORDER BY x, y LIMIT 1", + "SELECT 1, 1"); + assertQuery( + "WITH t AS (SELECT orderkey x, orderkey y FROM orders) SELECT x, y FROM t ORDER BY x, y LIMIT 1", + "SELECT 1, 1"); + assertQuery( + "WITH t AS (SELECT orderkey x, orderkey y FROM orders) SELECT x, y FROM t ORDER BY x, y DESC LIMIT 1", + "SELECT 1, 1"); + assertQuery( + "WITH t AS (SELECT orderkey x, totalprice y, orderkey z FROM orders) SELECT x, y, z FROM t ORDER BY x, y, z LIMIT 1", + "SELECT 1, 172799.49, 1"); + } + @Test public void testGroupByOrdinal() { @@ -4373,7 +4593,7 @@ public void testWindowMapAgg() MaterializedResult actual = computeActual("" + "SELECT map_agg(orderkey, orderpriority) OVER(PARTITION BY orderstatus) FROM\n" + "(SELECT * FROM orders ORDER BY orderkey LIMIT 5) t"); - MaterializedResult expected = resultBuilder(getSession(), new MapType(BIGINT, VarcharType.createVarcharType(1))) + MaterializedResult expected = resultBuilder(getSession(), mapType(BIGINT, VarcharType.createVarcharType(1))) .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) .row(ImmutableMap.of(1L, "5-LOW", 2L, "1-URGENT", 4L, "5-LOW")) @@ -5111,11 +5331,18 @@ public void testLargeIn() } @Test - public void testNullOnLhsOfInPredicateDisallowed() + public void testNullOnLhsOfInPredicateAllowed() { - String errorMessage = "\\QNULL values are not allowed on the probe side of SemiJoin operator\\E.*"; - assertQueryFails("SELECT NULL IN (SELECT 1)", errorMessage); - assertQueryFails("SELECT x FROM (VALUES NULL) t(x) WHERE x IN (SELECT 1)", errorMessage); + assertQuery("SELECT NULL IN (1, 2, 3)", "SELECT NULL"); + assertQuery("SELECT NULL IN (SELECT 1)", "SELECT NULL"); + assertQuery("SELECT NULL IN (SELECT 1 WHERE FALSE)", "SELECT FALSE"); + assertQuery("SELECT x FROM (VALUES NULL) t(x) WHERE x IN (SELECT 1)", "SELECT 33 WHERE FALSE"); + assertQuery("SELECT NULL IN (SELECT CAST(NULL AS BIGINT))", "SELECT NULL"); + assertQuery("SELECT NULL IN (SELECT NULL WHERE FALSE)", "SELECT FALSE"); + assertQuery("SELECT NULL IN ((SELECT 1) UNION ALL (SELECT NULL))", "SELECT NULL"); + assertQuery("SELECT x IN (SELECT TRUE) FROM (SELECT * FROM (VALUES CAST(NULL AS BOOLEAN)) t(x) WHERE (x OR NULL) IS NULL)", "SELECT NULL"); + assertQuery("SELECT x IN (SELECT 1) FROM (SELECT * FROM (VALUES CAST(NULL AS INTEGER)) t(x) WHERE (x + 10 IS NULL) OR X = 2)", "SELECT NULL"); + assertQuery("SELECT x IN (SELECT 1 WHERE FALSE) FROM (SELECT * FROM (VALUES CAST(NULL AS INTEGER)) t(x) WHERE (x + 10 IS NULL) OR X = 2)", "SELECT FALSE"); } @Test @@ -5495,6 +5722,48 @@ public void testShowColumns() format("%s does not matche neither of %s and %s", actual, expectedParametrizedVarchar, expectedUnparametrizedVarchar)); } + @Test + public void testShowStatsWithoutFromFails() + { + assertQueryFails("SHOW STATS FOR (SELECT 1)", ".*There must be exactly one table in query passed to SHOW STATS SELECT clause"); + } + + @Test + public void testShowStatsWithMultipleFromFails() + { + assertQueryFails("SHOW STATS FOR (SELECT * FROM orders, lineitem)", ".*There must be exactly one table in query passed to SHOW STATS SELECT clause"); + } + + @Test + public void testShowStatsWithGroupByFails() + { + assertQueryFails("SHOW STATS FOR (SELECT avg(totalprice) FROM orders GROUP BY clerk)", ".*GROUP BY is not supported in SHOW STATS SELECT clause"); + } + + @Test + public void testShowStatsWithHavingFails() + { + assertQueryFails("SHOW STATS FOR (SELECT avg(orderkey) FROM orders HAVING avg(orderkey) < 5)", ".*HAVING is not supported in SHOW STATS SELECT clause"); + } + + @Test + public void testShowStatsWithSelectDistinctFails() + { + assertQueryFails("SHOW STATS FOR (SELECT DISTINCT * FROM orders)", ".*DISTINCT is not supported by SHOW STATS SELECT clause"); + } + + @Test + public void testShowStatsWithSelectFunctionCallFails() + { + assertQueryFails("SHOW STATS FOR (SELECT sin(orderkey) FROM orders)", ".*Only \\* and column references are supported by SHOW STATS SELECT clause"); + } + + @Test + public void testShowStatsWithWhereFunctionCallFails() + { + assertQueryFails("SHOW STATS FOR (SELECT orderkey FROM orders WHERE sin(orderkey) > 0)", ".*Only literals, column references, comparators, is \\(not\\) null and logical operators are allowed in WHERE of SHOW STATS SELECT clause"); + } + @Test public void testAtTimeZone() { @@ -5780,6 +6049,18 @@ public void testChainedUnionsWithOrder() "SELECT orderkey FROM orders UNION (SELECT custkey FROM orders UNION SELECT linenumber FROM lineitem) UNION ALL SELECT orderkey FROM lineitem ORDER BY orderkey"); } + @Test + public void testUnionWithTopN() + { + assertQuery("SELECT * FROM (" + + " SELECT regionkey FROM nation " + + " UNION ALL " + + " SELECT nationkey FROM nation" + + ") t(a) " + + "ORDER BY a LIMIT 1", + "SELECT 0"); + } + @Test public void testUnionWithJoin() { @@ -5865,6 +6146,27 @@ public void testUnionWithAggregation() ); } + @Test + public void testUnionWithUnionAndAggregation() + { + assertQuery( + "SELECT count(*) FROM (" + + "SELECT 1 FROM nation GROUP BY regionkey " + + "UNION ALL " + + "SELECT 1 FROM (" + + " SELECT 1 FROM nation " + + " UNION ALL " + + " SELECT 1 FROM nation))"); + assertQuery( + "SELECT count(*) FROM (" + + "SELECT 1 FROM (" + + " SELECT 1 FROM nation " + + " UNION ALL " + + " SELECT 1 FROM nation)" + + "UNION ALL " + + "SELECT 1 FROM nation GROUP BY regionkey)"); + } + @Test public void testUnionWithAggregationAndTableScan() { @@ -6187,7 +6489,7 @@ public void testSemiJoin() // test multi level IN subqueries assertQuery("SELECT 1 IN (SELECT 1), 2 IN (SELECT 1) WHERE 1 IN (SELECT 1)"); - // test with subqueries on left + // test with subqueries on left assertQuery("SELECT (select 1) IN (SELECT 1)"); assertQuery("SELECT (select 2) IN (1, (SELECT 2))"); assertQuery("SELECT (2 + (select 1)) IN (SELECT 1)"); @@ -6273,6 +6575,18 @@ public void testAntiJoin() "FROM orders"); } + @Test + public void testAntiJoinNullHandling() + { + assertQuery("WITH empty AS (SELECT 1 WHERE FALSE) " + + "SELECT 3 FROM (VALUES 1) WHERE NULL NOT IN (SELECT * FROM empty)", + "VALUES 3"); + + assertQuery("WITH empty AS (SELECT 1 WHERE FALSE) " + + "SELECT x FROM (VALUES NULL) t(x) WHERE x NOT IN (SELECT * FROM empty)", + "VALUES NULL"); + } + @Test public void testSemiJoinLimitPushDown() { @@ -6288,10 +6602,13 @@ public void testSemiJoinLimitPushDown() " LIMIT 10)"); } - //Disabled till #6622 is fixed - @Test(enabled = false) + @Test public void testSemiJoinNullHandling() { + assertQuery("WITH empty AS (SELECT 1 WHERE FALSE) " + + "SELECT 3 FROM (VALUES 1) WHERE NULL IN (SELECT * FROM empty)", + "SELECT 0 WHERE FALSE"); + assertQuery("" + "SELECT orderkey\n" + " IN (\n" + @@ -6717,7 +7034,7 @@ public void testExistsSubqueryWithGroupBy() @Test public void testCorrelatedScalarSubqueries() { - String errorMsg = "Unsupported correlated subquery type"; + String errorMsg = "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"; assertQueryFails("SELECT (SELECT l.orderkey) FROM lineitem l", errorMsg); assertQueryFails("SELECT (SELECT 2 * l.orderkey) FROM lineitem l", errorMsg); @@ -6740,7 +7057,7 @@ public void testCorrelatedScalarSubqueries() assertQueryFails("SELECT * FROM lineitem l WHERE 1 = (SELECT (SELECT 2 * l.orderkey))", errorMsg); // explicit limit in subquery - assertQueryFails("SELECT (SELECT count(*) FROM (SELECT * FROM (values (7,1)) t(orderkey, value) WHERE orderkey = corr_key LIMIT 1)) FROM (values 7) t(corr_key)", errorMsg); + assertQueryFails("SELECT (SELECT count(*) FROM (VALUES (7,1)) t(orderkey, value) WHERE orderkey = corr_key LIMIT 1) FROM (values 7) t(corr_key)", errorMsg); } @Test @@ -6755,7 +7072,7 @@ public void testCorrelatedScalarSubqueriesWithCountScalarAggregationAndEqualityP assertQueryFails( "SELECT count(*) FROM nation n WHERE " + "(SELECT count(*) FROM (SELECT count(*) FROM region r WHERE n.regionkey = r.regionkey)) > 1", - "Unsupported correlated subquery type"); + "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"); // with duplicated rows assertQuery( @@ -6869,12 +7186,34 @@ public void testCorrelatedScalarSubqueriesWithScalarAggregation() " FROM nation n3 " + " WHERE n3.nationkey = n1.nationkey)" + "FROM nation n1"); + + //count in subquery + assertQuery("SELECT * " + + "FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " + + "WHERE v1.c1 > (SELECT count(c1) from (VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)", + "VALUES (2), (7)"); } @Test public void testCorrelatedInPredicateSubqueries() { - String errorMsg = "Unsupported correlated subquery type"; + String errorMsg = "Unexpected node: com.facebook.presto.sql.planner.plan.ApplyNode"; + + assertQuery("SELECT orderkey, clerk IN (SELECT clerk FROM orders s WHERE s.custkey = o.custkey AND s.orderkey < o.orderkey) FROM orders o"); + assertQuery("SELECT orderkey FROM orders o WHERE clerk IN (SELECT clerk FROM orders s WHERE s.custkey = o.custkey AND s.orderkey < o.orderkey)"); + + // all cases of IN (as one test query to avoid pruning, over-eager push down) + assertQuery( + "select t1.a, t1.b, " + + " t1.b in (select t2.b " + + " from (values (2, 3), (2, 4), (3, 0), (30,NULL)) t2(a, b) " + + " where t1.a - 5 <= t2.a and t2.a <= t1.a and 0 <= t2.a) " + + "from (values (1,1), (2,4), (3,5), (4,NULL), (30,2), (40,NULL) ) t1(a, b) " + + "order by t1.a", + "VALUES (1,1,FALSE), (2,4,TRUE), (3,5,FALSE), (4,NULL,NULL), (30,2,NULL), (40,NULL,FALSE)"); + + // subquery with limit (correlated filter below any unhandled node type) + assertQueryFails("SELECT orderkey FROM orders o WHERE clerk IN (SELECT clerk FROM orders s WHERE s.custkey = o.custkey AND s.orderkey < o.orderkey ORDER BY 1 LIMIT 1)", errorMsg); assertQueryFails("SELECT 1 IN (SELECT l.orderkey) FROM lineitem l", errorMsg); assertQueryFails("SELECT 1 IN (SELECT 2 * l.orderkey) FROM lineitem l", errorMsg); @@ -6890,7 +7229,9 @@ public void testCorrelatedInPredicateSubqueries() assertQueryFails("SELECT * FROM lineitem l1 JOIN lineitem l2 ON l1.orderkey IN (SELECT l2.orderkey)", errorMsg); // subrelation - assertQueryFails("SELECT * FROM lineitem l WHERE (SELECT * FROM (SELECT 1 IN (SELECT 2 * l.orderkey)))", errorMsg); + assertQueryFails( + "SELECT * FROM lineitem l WHERE (SELECT * FROM (SELECT 1 IN (SELECT 2 * l.orderkey)))", + "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"); // two level of nesting assertQueryFails("SELECT * FROM lineitem l WHERE true IN (SELECT 1 IN (SELECT 2 * l.orderkey))", errorMsg); @@ -6936,11 +7277,11 @@ public void testCorrelatedExistsSubqueriesWithEqualityPredicatesInWhere() assertQueryFails( "SELECT count(*) FROM orders o " + "WHERE EXISTS (SELECT avg(l.orderkey) FROM lineitem l WHERE o.orderkey = l.orderkey GROUP BY l.linenumber)", - "Unsupported correlated subquery type"); + "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"); assertQueryFails( "SELECT count(*) FROM orders o " + "WHERE EXISTS (SELECT count(*) FROM lineitem l WHERE o.orderkey = l.orderkey HAVING count(*) > 3)", - "Unsupported correlated subquery type"); + "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"); // with duplicated rows assertQuery( @@ -6973,6 +7314,12 @@ public void testCorrelatedExistsSubqueriesWithEqualityPredicatesInWhere() assertQuery( "SELECT count(*) FROM orders o WHERE (SELECT * FROM (SELECT EXISTS(SELECT 1 WHERE o.orderkey = 0)))", "SELECT count(*) FROM orders o WHERE o.orderkey = 0"); + + // not exists + assertQuery( + "SELECT count(*) FROM customer WHERE NOT EXISTS(SELECT * FROM orders WHERE orders.custkey=customer.custkey)", + "VALUES 500" + ); } @Test @@ -7036,7 +7383,7 @@ public void testCorrelatedExistsSubqueries() @Test public void testUnsupportedCorrelatedExistsSubqueries() { - String errorMsg = "Unsupported correlated subquery type"; + String errorMsg = "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"; assertQueryFails("SELECT EXISTS(SELECT 1 WHERE l.orderkey > 0 OR l.orderkey != 3) FROM lineitem l", errorMsg); assertQueryFails("SELECT count(*) FROM lineitem l WHERE EXISTS(SELECT 1 WHERE l.orderkey > 0 OR l.orderkey != 3)", errorMsg); @@ -8286,17 +8633,7 @@ public Object[][] qualifiedComparisonsCornerCases() parameter("quantifier").of("ALL", "ANY"), parameter("value").of("1", "NULL"), parameter("operator").of("=", "!=", "<", ">", "<=", ">=")); - //the following are disabled till #6622 is fixed - List excludedInPredicateQueries = ImmutableList.of( - "SELECT NULL != ALL (SELECT * FROM (SELECT 1 WHERE false))", - "SELECT NULL = ANY (SELECT * FROM (SELECT 1 WHERE false))", - "SELECT NULL != ALL (SELECT * FROM (SELECT CAST(NULL AS INTEGER)))", - "SELECT NULL = ANY (SELECT * FROM (SELECT CAST(NULL AS INTEGER)))", - "SELECT NULL = ANY (SELECT * FROM (VALUES (1), (NULL)))", - "SELECT NULL != ALL (SELECT * FROM (VALUES (1), (NULL)))" - ); - Predicate isExcluded = excludedInPredicateQueries::contains; - return toArgumentsArrays(queries.filter(isExcluded.negate()).map(Arguments::of)); + return toArgumentsArrays(queries.map(Arguments::of)); } @Test @@ -8465,7 +8802,7 @@ public void testSubqueriesWithDisjunction() "SELECT (SELECT true FROM (SELECT 1) t(a) WHERE a = nationkey) " + "FROM nation " + "WHERE (SELECT true FROM (SELECT 1) t(a) WHERE a = nationkey) OR TRUE", - "Unsupported correlated subquery type"); + "Unexpected node: com.facebook.presto.sql.planner.plan.LateralJoinNode"); } @Test @@ -8484,4 +8821,112 @@ public void testAssignUniqueId() "WHERE a = 1)", "VALUES 3008750"); } + + @Test + public void testAggregationPushedBelowOuterJoin() + { + assertQuery( + "SELECT * " + + "FROM nation n1 " + + "WHERE (n1.nationkey > ( " + + "SELECT avg(nationkey) " + + "FROM nation n2 " + + "WHERE n1.regionkey=n2.regionkey))" + ); + assertQuery( + "SELECT max(name), min(name), count(nationkey) + 1, count(nationkey) " + + "FROM (SELECT DISTINCT regionkey FROM region) as r1 " + + "LEFT JOIN " + + "nation " + + "ON r1.regionkey = nation.regionkey " + + "GROUP BY r1.regionkey " + + "HAVING sum(nationkey) < 20"); + + assertQuery( + "SELECT DISTINCT r1.regionkey " + + "FROM (SELECT regionkey FROM region INTERSECT SELECT regionkey FROM region where regionkey < 4) as r1 " + + "LEFT JOIN " + + "nation " + + "ON r1.regionkey = nation.regionkey"); + + assertQuery( + "SELECT max(nationkey) " + + "FROM (SELECT regionkey FROM region EXCEPT SELECT regionkey FROM region where regionkey < 4) as r1 " + + "LEFT JOIN " + + "nation " + + "ON r1.regionkey = nation.regionkey " + + "GROUP BY r1.regionkey"); + + assertQuery( + "SELECT max(nationkey) " + + "FROM (VALUES CAST (1 AS BIGINT)) v1(col1) " + + "LEFT JOIN " + + "nation " + + "ON v1.col1 = nation.regionkey " + + "GROUP BY v1.col1", + "VALUES 24"); + } + + @Test + public void testLateralJoin() + { + assertQuery( + "SELECT name FROM nation, LATERAL (SELECT 1 WHERE false)", + "SELECT 1 WHERE false"); + + assertQuery( + "SELECT name FROM nation, LATERAL (SELECT 1)", + "SELECT name FROM nation"); + + assertQuery( + "SELECT nationkey, a FROM nation, LATERAL (SELECT max(region.name) FROM region WHERE region.regionkey <= nation.regionkey) t(a) ORDER BY nationkey LIMIT 1", + "VALUES (0, 'AFRICA')"); + + assertQuery( + "SELECT nationkey, a FROM nation, LATERAL (SELECT region.name || '_' FROM region WHERE region.regionkey = nation.regionkey) t(a) ORDER BY nationkey LIMIT 1", + "VALUES (0, 'AFRICA_')"); + + assertQuery( + "SELECT nationkey, a, b, name FROM nation, LATERAL (SELECT nationkey + 2 AS a), LATERAL (SELECT a * -1 AS b) ORDER BY b LIMIT 1", + "VALUES (24, 26, -26, 'UNITED STATES')"); + + assertQuery( + "SELECT * FROM region r, LATERAL (SELECT * FROM nation) n WHERE n.regionkey = r.regionkey", + "SELECT * FROM region, nation WHERE nation.regionkey = region.regionkey"); + assertQuery( + "SELECT * FROM region, LATERAL (SELECT * FROM nation WHERE nation.regionkey = region.regionkey)", + "SELECT * FROM region, nation WHERE nation.regionkey = region.regionkey"); + + assertQuery( + "SELECT quantity, extendedprice, avg_price, low, high " + + "FROM lineitem, " + + "LATERAL (SELECT extendedprice / quantity AS avg_price) average_price, " + + "LATERAL (SELECT avg_price * 0.9 AS low) lower_bound, " + + "LATERAL (SELECT avg_price * 1.1 AS high) upper_bound " + + "ORDER BY extendedprice, quantity LIMIT 1", + "VALUES (1.0, 904.0, 904.0, 813.6, 994.400)"); + + assertQuery( + "SELECT y FROM (VALUES array[2, 3]) a(x) CROSS JOIN LATERAL(SELECT x[1]) b(y)", + "SELECT 2"); + assertQuery( + "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x + 1)", + "SELECT 2, 3"); + assertQuery( + "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x)", + "SELECT 2, 2"); + assertQuery( + "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x, x + 1)", + "SELECT 2, 2, 3"); + + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN LATERAL(VALUES x) ON true", + "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN LATERAL(VALUES x) ON true", + "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + assertQueryFails( + "SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN LATERAL(VALUES x) ON true", + "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 1ceb27a2415d4..a9cfeb63ee691 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -14,6 +14,8 @@ package com.facebook.presto.tests; import com.facebook.presto.Session; +import com.facebook.presto.cost.CoefficientBasedCostCalculator; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.spi.type.Type; @@ -28,6 +30,7 @@ import com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilege; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; import org.intellij.lang.annotations.Language; import org.testng.SkipException; import org.testng.annotations.AfterClass; @@ -42,6 +45,7 @@ import static com.facebook.presto.sql.SqlFormatter.formatSql; import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.nullToEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.testing.Closeables.closeAllRuntimeException; import static java.lang.String.format; @@ -56,6 +60,7 @@ public abstract class AbstractTestQueryFramework private QueryRunner queryRunner; private H2QueryRunner h2QueryRunner; private SqlParser sqlParser; + private CostCalculator costCalculator; protected AbstractTestQueryFramework(QueryRunnerSupplier supplier) { @@ -69,6 +74,7 @@ public void init() queryRunner = queryRunnerSupplier.get(); h2QueryRunner = new H2QueryRunner(); sqlParser = new SqlParser(); + costCalculator = new CoefficientBasedCostCalculator(queryRunner.getMetadata()); } @AfterClass(alwaysRun = true) @@ -167,20 +173,19 @@ protected void assertUpdate(Session session, @Language("SQL") String sql, long c QueryAssertions.assertUpdate(queryRunner, session, sql, OptionalLong.of(count)); } + protected void assertQueryFailsEventually(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, Duration timeout) + { + QueryAssertions.assertQueryFailsEventually(queryRunner, getSession(), sql, expectedMessageRegExp, timeout); + } + protected void assertQueryFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { - assertQueryFails(getSession(), sql, expectedMessageRegExp); + QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp); } protected void assertQueryFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { - try { - queryRunner.execute(session, sql); - fail(format("Expected query to fail: %s", sql)); - } - catch (RuntimeException ex) { - assertExceptionMessage(sql, ex, expectedMessageRegExp); - } + QueryAssertions.assertQueryFails(queryRunner, session, sql, expectedMessageRegExp); } protected void assertAccessAllowed(@Language("SQL") String sql, TestingPrivilege... deniedPrivileges) @@ -239,7 +244,7 @@ protected void assertTableColumnNames(String tableName, String... columnNames) private static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex) { - if (!exception.getMessage().matches(regex)) { + if (!nullToEmpty(exception.getMessage()).matches(regex)) { fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception); } } @@ -296,6 +301,7 @@ private QueryExplainer getQueryExplainer() metadata, queryRunner.getAccessControl(), sqlParser, + costCalculator, ImmutableMap.of()); } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java index ca92bb2b3543f..6f83ed817618b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestingPrestoClient.java @@ -27,11 +27,8 @@ import com.google.common.base.Function; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.http.client.HttpClient; -import io.airlift.http.client.HttpClientConfig; -import io.airlift.http.client.jetty.JettyHttpClient; -import io.airlift.json.JsonCodec; import io.airlift.units.Duration; +import okhttp3.OkHttpClient; import org.intellij.lang.annotations.Language; import java.io.Closeable; @@ -45,51 +42,44 @@ import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.transform; -import static io.airlift.json.JsonCodec.jsonCodec; import static java.util.Objects.requireNonNull; public abstract class AbstractTestingPrestoClient implements Closeable { - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); - private final TestingPrestoServer prestoServer; private final Session defaultSession; - private final HttpClient httpClient; + private final OkHttpClient httpClient = new OkHttpClient(); protected AbstractTestingPrestoClient(TestingPrestoServer prestoServer, Session defaultSession) { this.prestoServer = requireNonNull(prestoServer, "prestoServer is null"); this.defaultSession = requireNonNull(defaultSession, "defaultSession is null"); - - this.httpClient = new JettyHttpClient( - new HttpClientConfig() - .setConnectTimeout(new Duration(1, TimeUnit.DAYS)) - .setIdleTimeout(new Duration(10, TimeUnit.DAYS))); } @Override public void close() { - this.httpClient.close(); + httpClient.dispatcher().executorService().shutdown(); + httpClient.connectionPool().evictAll(); } protected abstract ResultsSession getResultSession(Session session); - public T execute(@Language("SQL") String sql) + public ResultWithQueryId execute(@Language("SQL") String sql) { return execute(defaultSession, sql); } - public T execute(Session session, @Language("SQL") String sql) + public ResultWithQueryId execute(Session session, @Language("SQL") String sql) { ResultsSession resultsSession = getResultSession(session); ClientSession clientSession = toClientSession(session, prestoServer.getBaseUrl(), true, new Duration(2, TimeUnit.MINUTES)); - try (StatementClient client = new StatementClient(httpClient, QUERY_RESULTS_CODEC, clientSession, sql)) { + try (StatementClient client = new StatementClient(httpClient, clientSession, sql)) { while (client.isValid()) { QueryResults results = client.current(); @@ -106,7 +96,8 @@ public T execute(Session session, @Language("SQL") String sql) resultsSession.setUpdateCount(results.getUpdateCount()); } - return resultsSession.build(client.getSetSessionProperties(), client.getResetSessionProperties()); + T result = resultsSession.build(client.getSetSessionProperties(), client.getResetSessionProperties()); + return new ResultWithQueryId<>(results.getId(), result); } QueryError error = client.finalResults().getError(); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index 7c31e721005fe..17144487a9791 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.metadata.AllNodes; @@ -25,7 +26,9 @@ import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.QueryId; import com.facebook.presto.sql.parser.SqlParserOptions; +import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.TestingAccessControlManager; @@ -225,6 +228,12 @@ public Metadata getMetadata() return coordinator.getMetadata(); } + @Override + public CostCalculator getCostCalculator() + { + return coordinator.getCostCalculator(); + } + @Override public TestingAccessControlManager getAccessControl() { @@ -323,7 +332,7 @@ public MaterializedResult execute(@Language("SQL") String sql) { lock.readLock().lock(); try { - return prestoClient.execute(sql); + return prestoClient.execute(sql).getResult(); } finally { lock.readLock().unlock(); @@ -332,6 +341,17 @@ public MaterializedResult execute(@Language("SQL") String sql) @Override public MaterializedResult execute(Session session, @Language("SQL") String sql) + { + lock.readLock().lock(); + try { + return prestoClient.execute(session, sql).getResult(); + } + finally { + lock.readLock().unlock(); + } + } + + public ResultWithQueryId executeWithQueryId(Session session, @Language("SQL") String sql) { lock.readLock().lock(); try { @@ -342,6 +362,16 @@ public MaterializedResult execute(Session session, @Language("SQL") String sql) } } + public QueryInfo getQueryInfo(QueryId queryId) + { + return coordinator.getQueryManager().getQueryInfo(queryId); + } + + public Plan getQueryPlan(QueryId queryId) + { + return coordinator.getQueryManager().getQueryPlan(queryId); + } + @Override public Lock getExclusiveLock() { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java b/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java index 9ba497f62aa80..2a03128eb4af2 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/PlanDeterminismChecker.java @@ -62,7 +62,7 @@ private String getPlanText(Session session, String sql) { return localQueryRunner.inTransaction(session, transactionSession -> { Plan plan = localQueryRunner.createPlan(transactionSession, sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED); - return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), localQueryRunner.getMetadata(), transactionSession); + return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), localQueryRunner.getMetadata(), localQueryRunner.getCostCalculator(), transactionSession); }); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java index 13b924a83e094..51bb93f135947 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java @@ -28,10 +28,14 @@ import java.util.List; import java.util.OptionalLong; +import java.util.function.Supplier; +import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; import static java.util.Locale.ENGLISH; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -90,7 +94,7 @@ public static void assertQuery(QueryRunner actualQueryRunner, expectedResults = h2QueryRunner.execute(session, expected, actualResults.getTypes()); } catch (RuntimeException ex) { - fail("Execution of 'expected' query failed: " + actual, ex); + fail("Execution of 'expected' query failed: " + expected, ex); } log.info("FINISHED in presto: %s, h2: %s, total: %s", actualTime, nanosSince(expectedStart), nanosSince(start)); @@ -152,6 +156,23 @@ public static void assertEqualsIgnoreOrder(Iterable actual, Iterable expec } } + public static void assertContainsEventually(Supplier all, MaterializedResult expectedSubset, Duration timeout) + { + long start = System.nanoTime(); + while (!Thread.currentThread().isInterrupted()) { + try { + assertContains(all.get(), expectedSubset); + return; + } + catch (AssertionError e) { + if (nanosSince(start).compareTo(timeout) > 0) { + throw e; + } + } + sleepUninterruptibly(50, MILLISECONDS); + } + } + public static void assertContains(MaterializedResult all, MaterializedResult expectedSubset) { for (MaterializedRow row : expectedSubset.getMaterializedRows()) { @@ -166,6 +187,41 @@ public static void assertContains(MaterializedResult all, MaterializedResult exp } } + protected static void assertQueryFailsEventually(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, Duration timeout) + { + long start = System.nanoTime(); + while (!Thread.currentThread().isInterrupted()) { + try { + assertQueryFails(queryRunner, session, sql, expectedMessageRegExp); + return; + } + catch (AssertionError e) { + if (nanosSince(start).compareTo(timeout) > 0) { + throw e; + } + } + sleepUninterruptibly(50, MILLISECONDS); + } + } + + protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) + { + try { + queryRunner.execute(session, sql); + fail(format("Expected query to fail: %s", sql)); + } + catch (RuntimeException ex) { + assertExceptionMessage(sql, ex, expectedMessageRegExp); + } + } + + private static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex) + { + if (!nullToEmpty(exception.getMessage()).matches(regex)) { + fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception); + } + } + public static void copyTpchTables( QueryRunner queryRunner, String sourceCatalog, diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/ResultWithQueryId.java b/presto-tests/src/main/java/com/facebook/presto/tests/ResultWithQueryId.java new file mode 100644 index 0000000000000..f2e8fbe0edabd --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/ResultWithQueryId.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +public class ResultWithQueryId +{ + private final String queryId; + private final T result; + + public ResultWithQueryId(String queryId, T result) + { + this.queryId = queryId; + this.result = result; + } + + public String getQueryId() + { + return queryId; + } + + public T getResult() + { + return result; + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java index 7e808ec4ac186..1271f4c3d235c 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.metadata.AllNodes; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.QualifiedObjectName; @@ -82,7 +83,7 @@ public MaterializedResult execute(@Language("SQL") String sql) { lock.readLock().lock(); try { - return prestoClient.execute(sql); + return prestoClient.execute(sql).getResult(); } finally { lock.readLock().unlock(); @@ -94,7 +95,7 @@ public MaterializedResult execute(Session session, @Language("SQL") String sql) { lock.readLock().lock(); try { - return prestoClient.execute(session, sql); + return prestoClient.execute(session, sql).getResult(); } finally { lock.readLock().unlock(); @@ -132,6 +133,12 @@ public Metadata getMetadata() return server.getMetadata(); } + @Override + public CostCalculator getCostCalculator() + { + return server.getCostCalculator(); + } + @Override public TestingAccessControlManager getAccessControl() { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java b/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java index a42b15a76697f..1522c74193f72 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/StructuralTestUtil.java @@ -13,13 +13,21 @@ */ package com.facebook.presto.tests; +import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.Decimals; +import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.spi.type.TypeSignatureParameter; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; @@ -31,6 +39,12 @@ public final class StructuralTestUtil { + private static final TypeManager TYPE_MANAGER = new TypeRegistry(); + static { + // associate TYPE_MANAGER with a function registry + new FunctionRegistry(TYPE_MANAGER, new BlockEncodingManager(TYPE_MANAGER), new FeaturesConfig()); + } + private StructuralTestUtil() {} public static boolean arrayBlocksEqual(Type elementType, Block block1, Block block2) @@ -79,23 +93,29 @@ public static Block arrayBlockOf(Type elementType, Object... values) public static Block mapBlockOf(Type keyType, Type valueType, Object key, Object value) { - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), 1024); - appendToBlockBuilder(keyType, key, blockBuilder); - appendToBlockBuilder(valueType, value, blockBuilder); - return blockBuilder.build(); + MapType mapType = mapType(keyType, valueType); + BlockBuilder blockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 10); + BlockBuilder singleMapBlockWriter = blockBuilder.beginBlockEntry(); + appendToBlockBuilder(keyType, key, singleMapBlockWriter); + appendToBlockBuilder(valueType, value, singleMapBlockWriter); + blockBuilder.closeEntry(); + return mapType.getObject(blockBuilder, 0); } public static Block mapBlockOf(Type keyType, Type valueType, Object[] keys, Object[] values) { checkArgument(keys.length == values.length, "keys/values must have the same length"); - BlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), 1024); + MapType mapType = mapType(keyType, valueType); + BlockBuilder blockBuilder = mapType.createBlockBuilder(new BlockBuilderStatus(), 10); + BlockBuilder singleMapBlockWriter = blockBuilder.beginBlockEntry(); for (int i = 0; i < keys.length; i++) { Object key = keys[i]; Object value = values[i]; - appendToBlockBuilder(keyType, key, blockBuilder); - appendToBlockBuilder(valueType, value, blockBuilder); + appendToBlockBuilder(keyType, key, singleMapBlockWriter); + appendToBlockBuilder(valueType, value, singleMapBlockWriter); } - return blockBuilder.build(); + blockBuilder.closeEntry(); + return mapType.getObject(blockBuilder, 0); } public static Block rowBlockOf(List parameterTypes, Object... values) @@ -130,4 +150,11 @@ public static Block decimalMapBlockOf(DecimalType type, BigDecimal decimal) return mapBlockOf(type, type, sliceDecimal, sliceDecimal); } } + + public static MapType mapType(Type keyType, Type valueType) + { + return (MapType) TYPE_MANAGER.getParameterizedType(StandardTypes.MAP, ImmutableList.of( + TypeSignatureParameter.of(keyType.getTypeSignature()), + TypeSignatureParameter.of(valueType.getTypeSignature()))); + } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java index 6e4c4c6469209..ca2cb4f60c766 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java @@ -18,14 +18,14 @@ import com.facebook.presto.client.IntervalYearMonth; import com.facebook.presto.client.QueryResults; import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.TimeZoneKey; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; -import com.facebook.presto.type.ArrayType; -import com.facebook.presto.type.MapType; import com.facebook.presto.type.SqlIntervalDayTime; import com.facebook.presto.type.SqlIntervalYearMonth; import com.google.common.base.Function; @@ -238,6 +238,9 @@ else if (type instanceof MapType) { else if (type instanceof DecimalType) { return new BigDecimal((String) value); } + else if (type.getTypeSignature().getBase().equals("ObjectId")) { + return value; + } else { throw new AssertionError("unhandled type: " + type); } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/datatype/DataType.java b/presto-tests/src/main/java/com/facebook/presto/tests/datatype/DataType.java index f60231db977d9..d1ac4dab0f41c 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/datatype/DataType.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/datatype/DataType.java @@ -23,7 +23,6 @@ import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; import static com.google.common.base.Strings.padEnd; import static java.util.Optional.empty; -import static java.util.Optional.of; public class DataType { @@ -39,7 +38,7 @@ public static DataType varcharDataType(int size) public static DataType varcharDataType(int size, String properties) { - return varcharDataType(of(size), properties); + return varcharDataType(Optional.of(size), properties); } public static DataType varcharDataType() diff --git a/presto-hive-cdh4/src/test/java/com/facebook/presto/hive/TestHiveClient.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java similarity index 50% rename from presto-hive-cdh4/src/test/java/com/facebook/presto/hive/TestHiveClient.java rename to presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java index 0bde532d2d4c2..bf3d0205af4de 100644 --- a/presto-hive-cdh4/src/test/java/com/facebook/presto/hive/TestHiveClient.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/Metric.java @@ -11,20 +11,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.hive; +package com.facebook.presto.tests.statistics; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.spi.statistics.Estimate; -@Test(groups = "hive") -public class TestHiveClient - extends AbstractTestHiveClient +import java.util.function.Function; + +public enum Metric { - @Parameters({"hive.cdh4.metastoreHost", "hive.cdh4.metastorePort", "hive.cdh4.databaseName", "hive.cdh4.timeZone"}) - @BeforeClass - public void initialize(String host, int port, String databaseName, String timeZone) + OUTPUT_ROW_COUNT(PlanNodeCost::getOutputRowCount), + OUTPUT_SIZE_BYTES(PlanNodeCost::getOutputSizeInBytes); + + private final Function extractor; + + Metric(Function extractor) + { + this.extractor = extractor; + } + + Estimate getValue(PlanNodeCost cost) { - setup(host, port, databaseName, timeZone); + return extractor.apply(cost); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java new file mode 100644 index 0000000000000..053b758602b1f --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparator.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.statistics; + +import com.facebook.presto.cost.PlanNodeCost; +import com.facebook.presto.execution.StageInfo; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.planner.planPrinter.PlanNodeStats; +import com.facebook.presto.sql.planner.planPrinter.PlanNodeStatsSummarizer; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BinaryOperator; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.execution.StageInfo.getAllStages; +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.util.MoreMaps.mergeMaps; +import static com.google.common.collect.Maps.transformValues; +import static java.util.Arrays.asList; + +public class MetricComparator +{ + private final List metrics = asList(Metric.values()); + private final double tolerance = 0.1; + + public List getMetricComparisons(Plan queryPlan, StageInfo outputStageInfo) + { + return metrics.stream().flatMap(metric -> { + Map estimates = queryPlan.getPlanNodeCosts(); + Map actuals = extractActualCosts(outputStageInfo); + return estimates.entrySet().stream().map(entry -> { + // todo refactor to stay in PlanNodeId domain ???? + PlanNode node = planNodeForId(queryPlan, entry.getKey()); + PlanNodeCost estimate = entry.getValue(); + Optional execution = Optional.ofNullable(actuals.get(node.getId())); + return createMetricComparison(metric, node, estimate, execution); + }); + }).collect(Collectors.toList()); + } + + private PlanNode planNodeForId(Plan queryPlan, PlanNodeId id) + { + return searchFrom(queryPlan.getRoot()) + .where(node -> node.getId().equals(id)) + .findOnlyElement(); + } + + private Map extractActualCosts(StageInfo outputStageInfo) + { + Stream> stagesStatsStream = + getAllStages(Optional.of(outputStageInfo)).stream() + .map(PlanNodeStatsSummarizer::aggregatePlanNodeStats); + + Map mergedStats = mergeStats(stagesStatsStream); + return transformValues(mergedStats, this::toPlanNodeCost); + } + + private Map mergeStats(Stream> stagesStatsStream) + { + BinaryOperator allowNoDuplicates = (a, b) -> { + throw new IllegalArgumentException("PlanNodeIds must be unique"); + }; + return mergeMaps(stagesStatsStream, allowNoDuplicates); + } + + private PlanNodeCost toPlanNodeCost(PlanNodeStats operatorStats) + { + return PlanNodeCost.builder() + .setOutputRowCount(new Estimate(operatorStats.getPlanNodeOutputPositions())) + .setOutputSizeInBytes(new Estimate(operatorStats.getPlanNodeOutputDataSize().toBytes())) + .build(); + } + + private MetricComparison createMetricComparison(Metric metric, PlanNode node, PlanNodeCost estimate, Optional execution) + { + Optional estimatedCost = asOptional(metric.getValue(estimate)); + Optional executionCost = execution.flatMap(e -> asOptional(metric.getValue(e))); + return new MetricComparison(node, metric, estimatedCost, executionCost, tolerance); + } + + private Optional asOptional(Estimate estimate) + { + return estimate.isValueUnknown() ? Optional.empty() : Optional.of(estimate.getValue()); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java new file mode 100644 index 0000000000000..3ab2b50682d7b --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/statistics/MetricComparison.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.statistics; + +import com.facebook.presto.sql.planner.plan.PlanNode; + +import java.util.Optional; + +import static com.facebook.presto.tests.statistics.MetricComparison.Result.DIFFER; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.MATCH; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_BASELINE; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_ESTIMATE; +import static java.lang.Math.abs; +import static java.lang.String.format; + +public class MetricComparison +{ + private final PlanNode planNode; + private final Metric metric; + private final Optional estimatedCost; + private final Optional executionCost; + private final double tolerance; + + public MetricComparison(PlanNode planNode, Metric metric, Optional estimatedCost, Optional executionCost, double tolerance) + { + this.planNode = planNode; + this.metric = metric; + this.estimatedCost = estimatedCost; + this.executionCost = executionCost; + this.tolerance = tolerance; + } + + public Metric getMetric() + { + return metric; + } + + public PlanNode getPlanNode() + { + return planNode; + } + + @Override + public String toString() + { + return format("Metric [%s] - [%s] - estimated: [%s], real: [%s] - plan node: [%s]", + metric, result(), print(estimatedCost), print(executionCost), planNode); + } + + public Result result() + { + return estimatedCost + .map(estimate -> executionCost + .map(execution -> estimateMatchesReality(estimate, execution) ? MATCH : DIFFER) + .orElse(NO_BASELINE)) + .orElse(NO_ESTIMATE); + } + + private String print(Optional cost) + { + return cost.map(Object::toString).orElse("UNKNOWN"); + } + + private boolean estimateMatchesReality(double estimate, double execution) + { + return abs(execution - estimate) / execution < tolerance; + } + + public enum Result + { + NO_ESTIMATE, + NO_BASELINE, + DIFFER, + MATCH + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/tpch/TpchIndexedData.java b/presto-tests/src/main/java/com/facebook/presto/tests/tpch/TpchIndexedData.java index bbac735b5d077..9ea75d9e5e74b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/tpch/TpchIndexedData.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/tpch/TpchIndexedData.java @@ -72,7 +72,7 @@ public TpchIndexedData(String connectorId, TpchIndexSpec tpchIndexSpec) .collect(toImmutableSet()); TpchTable tpchTable = TpchTable.getTable(table.getTableName()); - RecordSet recordSet = tpchRecordSetProvider.getRecordSet(tpchTable, ImmutableList.copyOf(columnHandles.values()), table.getScaleFactor(), 0, 1); + RecordSet recordSet = tpchRecordSetProvider.getRecordSet(tpchTable, ImmutableList.copyOf(columnHandles.values()), table.getScaleFactor(), 0, 1, Optional.empty()); IndexedTable indexedTable = indexTable(recordSet, ImmutableList.copyOf(columnHandles.keySet()), keyColumnNames); indexedTablesBuilder.put(keyColumns, indexedTable); } diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestEventListener.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestEventListener.java index 409c609961bbb..5de2acfd19e3d 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/TestEventListener.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestEventListener.java @@ -15,11 +15,11 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.TestEventListenerPlugin.TestingEventListenerPlugin; +import com.facebook.presto.resourceGroups.ResourceGroupManagerPlugin; import com.facebook.presto.spi.eventlistener.QueryCompletedEvent; import com.facebook.presto.spi.eventlistener.QueryCreatedEvent; import com.facebook.presto.spi.eventlistener.SplitCompletedEvent; import com.facebook.presto.testing.MaterializedResult; -import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.DistributedQueryRunner; import com.facebook.presto.tpch.TpchPlugin; import com.google.common.collect.ImmutableList; @@ -40,6 +40,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.stream.Collectors.toSet; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; @Test(singleThreaded = true) public class TestEventListener @@ -47,7 +48,7 @@ public class TestEventListener private static final int SPLITS_PER_NODE = 3; private final EventsBuilder generatedEvents = new EventsBuilder(); - private QueryRunner queryRunner; + private DistributedQueryRunner queryRunner; private Session session; @BeforeClass @@ -60,10 +61,18 @@ private void setUp() .setSchema("tiny") .setClientInfo("{\"clientVersion\":\"testVersion\"}") .build(); - queryRunner = new DistributedQueryRunner(session, 1); + queryRunner = new DistributedQueryRunner(session, 1, ImmutableMap.of("experimental.resource-groups-enabled", "true")); queryRunner.installPlugin(new TpchPlugin()); queryRunner.installPlugin(new TestingEventListenerPlugin(generatedEvents)); + queryRunner.installPlugin(new ResourceGroupManagerPlugin()); queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of("tpch.splits-per-node", Integer.toString(SPLITS_PER_NODE))); + queryRunner.getCoordinator().getResourceGroupManager().get() + .setConfigurationManager("file", ImmutableMap.of("resource-groups.config-file", getResourceFilePath("resource_groups_config_simple.json"))); + } + + private String getResourceFilePath(String fileName) + { + return this.getClass().getClassLoader().getResource(fileName).getPath(); } @AfterClass(alwaysRun = true) @@ -72,14 +81,14 @@ private void tearDown() queryRunner.close(); } - private EventsBuilder generateEvents(@Language("SQL") String sql, int numEventsExpected) + private MaterializedResult runQueryAndWaitForEvents(@Language("SQL") String sql, int numEventsExpected) throws Exception { generatedEvents.initialize(numEventsExpected); - queryRunner.execute(session, sql); + MaterializedResult result = queryRunner.execute(session, sql); generatedEvents.waitForEvents(10); - return generatedEvents; + return result; } @Test @@ -87,21 +96,23 @@ public void testConstantQuery() throws Exception { // QueryCreated: 1, QueryCompleted: 1, Splits: 1 - EventsBuilder events = generateEvents("SELECT 1", 3); + runQueryAndWaitForEvents("SELECT 1", 3); - QueryCreatedEvent queryCreatedEvent = getOnlyElement(events.getQueryCreatedEvents()); + QueryCreatedEvent queryCreatedEvent = getOnlyElement(generatedEvents.getQueryCreatedEvents()); assertEquals(queryCreatedEvent.getContext().getServerVersion(), "testversion"); assertEquals(queryCreatedEvent.getContext().getServerAddress(), "127.0.0.1"); assertEquals(queryCreatedEvent.getContext().getEnvironment(), "testing"); assertEquals(queryCreatedEvent.getContext().getClientInfo().get(), "{\"clientVersion\":\"testVersion\"}"); assertEquals(queryCreatedEvent.getMetadata().getQuery(), "SELECT 1"); - QueryCompletedEvent queryCompletedEvent = getOnlyElement(events.getQueryCompletedEvents()); + QueryCompletedEvent queryCompletedEvent = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + assertTrue(queryCompletedEvent.getContext().getResourceGroupName().isPresent()); + assertEquals(queryCompletedEvent.getContext().getResourceGroupName().get(), "global.user-user"); assertEquals(queryCompletedEvent.getStatistics().getTotalRows(), 0L); assertEquals(queryCompletedEvent.getContext().getClientInfo().get(), "{\"clientVersion\":\"testVersion\"}"); assertEquals(queryCreatedEvent.getMetadata().getQueryId(), queryCompletedEvent.getMetadata().getQueryId()); - List splitCompletedEvents = events.getSplitCompletedEvents(); + List splitCompletedEvents = generatedEvents.getSplitCompletedEvents(); assertEquals(splitCompletedEvents.get(0).getQueryId(), queryCompletedEvent.getMetadata().getQueryId()); assertEquals(splitCompletedEvents.get(0).getStatistics().getCompletedPositions(), 1); } @@ -113,16 +124,18 @@ public void testNormalQuery() // We expect the following events // QueryCreated: 1, QueryCompleted: 1, Splits: SPLITS_PER_NODE (leaf splits) + LocalExchange[SINGLE] split + Aggregation/Output split int expectedEvents = 1 + 1 + SPLITS_PER_NODE + 1 + 1; - EventsBuilder events = generateEvents("SELECT sum(linenumber) FROM lineitem", expectedEvents); + runQueryAndWaitForEvents("SELECT sum(linenumber) FROM lineitem", expectedEvents); - QueryCreatedEvent queryCreatedEvent = getOnlyElement(events.getQueryCreatedEvents()); + QueryCreatedEvent queryCreatedEvent = getOnlyElement(generatedEvents.getQueryCreatedEvents()); assertEquals(queryCreatedEvent.getContext().getServerVersion(), "testversion"); assertEquals(queryCreatedEvent.getContext().getServerAddress(), "127.0.0.1"); assertEquals(queryCreatedEvent.getContext().getEnvironment(), "testing"); assertEquals(queryCreatedEvent.getContext().getClientInfo().get(), "{\"clientVersion\":\"testVersion\"}"); assertEquals(queryCreatedEvent.getMetadata().getQuery(), "SELECT sum(linenumber) FROM lineitem"); - QueryCompletedEvent queryCompletedEvent = getOnlyElement(events.getQueryCompletedEvents()); + QueryCompletedEvent queryCompletedEvent = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + assertTrue(queryCompletedEvent.getContext().getResourceGroupName().isPresent()); + assertEquals(queryCompletedEvent.getContext().getResourceGroupName().get(), "global.user-user"); assertEquals(queryCompletedEvent.getIoMetadata().getOutput(), Optional.empty()); assertEquals(queryCompletedEvent.getIoMetadata().getInputs().size(), 1); assertEquals(queryCompletedEvent.getContext().getClientInfo().get(), "{\"clientVersion\":\"testVersion\"}"); @@ -130,7 +143,7 @@ public void testNormalQuery() assertEquals(queryCreatedEvent.getMetadata().getQueryId(), queryCompletedEvent.getMetadata().getQueryId()); assertEquals(queryCompletedEvent.getStatistics().getCompletedSplits(), SPLITS_PER_NODE + 2); - List splitCompletedEvents = events.getSplitCompletedEvents(); + List splitCompletedEvents = generatedEvents.getSplitCompletedEvents(); assertEquals(splitCompletedEvents.size(), SPLITS_PER_NODE + 2); // leaf splits + aggregation split // All splits must have the same query ID @@ -145,7 +158,7 @@ public void testNormalQuery() .mapToLong(e -> e.getStatistics().getCompletedPositions()) .sum(); - MaterializedResult result = queryRunner.execute(session, "SELECT count(*) FROM lineitem"); + MaterializedResult result = runQueryAndWaitForEvents("SELECT count(*) FROM lineitem", expectedEvents); long expectedCompletedPositions = (long) result.getMaterializedRows().get(0).getField(0); assertEquals(actualCompletedPositions, expectedCompletedPositions); diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java b/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java index cc75dd709ac77..760dbf0a6df43 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/TestQueues.java @@ -34,6 +34,8 @@ import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static org.testng.Assert.assertEquals; +// run single threaded to avoid creating multiple query runners at once +@Test(singleThreaded = true) public class TestQueues { private static final String LONG_LASTING_QUERY = "SELECT COUNT(*) FROM lineitem"; diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroupIntegration.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroupIntegration.java index b5401638370fe..8756ec828295c 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroupIntegration.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroupIntegration.java @@ -33,18 +33,11 @@ public void testMemoryFraction() { try (DistributedQueryRunner queryRunner = createQueryRunner(ImmutableMap.of(), ImmutableMap.of("experimental.resource-groups-enabled", "true"))) { queryRunner.installPlugin(new ResourceGroupManagerPlugin()); - queryRunner.getCoordinator().getResourceGroupManager().get().setConfigurationManager("file", ImmutableMap.of("resource-groups.config-file", getResourceFilePath("resource_groups_memory_percentage.json"))); + getResourceGroupManager(queryRunner).setConfigurationManager("file", ImmutableMap.of( + "resource-groups.config-file", getResourceFilePath("resource_groups_memory_percentage.json"))); queryRunner.execute("SELECT COUNT(*), clerk FROM orders GROUP BY clerk"); - long startTime = System.nanoTime(); - while (true) { - SECONDS.sleep(1); - ResourceGroupInfo global = queryRunner.getCoordinator().getResourceGroupManager().get().getResourceGroupInfo(new ResourceGroupId("global")); - if (global.getSoftMemoryLimit().toBytes() > 0) { - break; - } - assertLessThan(nanosSince(startTime).roundTo(SECONDS), 60L); - } + waitForGlobalResourceGroup(queryRunner); } } @@ -52,4 +45,24 @@ private String getResourceFilePath(String fileName) { return this.getClass().getClassLoader().getResource(fileName).getPath(); } + + public static void waitForGlobalResourceGroup(DistributedQueryRunner queryRunner) + throws InterruptedException + { + long startTime = System.nanoTime(); + while (true) { + SECONDS.sleep(1); + ResourceGroupInfo global = getResourceGroupManager(queryRunner).getResourceGroupInfo(new ResourceGroupId("global")); + if (global.getSoftMemoryLimit().toBytes() > 0) { + break; + } + assertLessThan(nanosSince(startTime).roundTo(SECONDS), 60L); + } + } + + private static InternalResourceGroupManager getResourceGroupManager(DistributedQueryRunner queryRunner) + { + return queryRunner.getCoordinator().getResourceGroupManager() + .orElseThrow(() -> new IllegalArgumentException("no resource group manager")); + } } diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/H2TestUtil.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/H2TestUtil.java new file mode 100644 index 0000000000000..e6f4550b647e8 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/H2TestUtil.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.resourceGroups.db; + +import com.facebook.presto.Session; +import com.facebook.presto.execution.QueryManager; +import com.facebook.presto.execution.QueryState; +import com.facebook.presto.resourceGroups.db.DbResourceGroupConfig; +import com.facebook.presto.resourceGroups.db.H2DaoProvider; +import com.facebook.presto.resourceGroups.db.H2ResourceGroupsDao; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; +import com.facebook.presto.sql.parser.SqlParserOptions; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Random; +import java.util.Set; + +import static com.facebook.presto.execution.QueryState.RUNNING; +import static com.facebook.presto.execution.QueryState.TERMINAL_QUERY_STATES; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +class H2TestUtil +{ + private static final String CONFIGURATION_MANAGER_TYPE = "h2"; + + private H2TestUtil() {} + + public static Session adhocSession() + { + return testSessionBuilder() + .setCatalog("tpch") + .setSchema("sf100000") + .setSource("adhoc") + .build(); + } + + public static Session dashboardSession() + { + return testSessionBuilder() + .setCatalog("tpch") + .setSchema("sf100000") + .setSource("dashboard") + .build(); + } + + public static Session rejectingSession() + { + return testSessionBuilder() + .setCatalog("tpch") + .setSchema("sf100000") + .setSource("reject") + .build(); + } + + public static void waitForCompleteQueryCount(DistributedQueryRunner queryRunner, int expectedCount) + throws InterruptedException + { + waitForQueryCount(queryRunner, TERMINAL_QUERY_STATES, expectedCount); + } + + public static void waitForRunningQueryCount(DistributedQueryRunner queryRunner, int expectedCount) + throws InterruptedException + { + waitForQueryCount(queryRunner, ImmutableSet.of(RUNNING), expectedCount); + } + + public static void waitForQueryCount(DistributedQueryRunner queryRunner, Set countingStates, int expectedCount) + throws InterruptedException + { + QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); + while (queryManager.getAllQueryInfo().stream() + .filter(q -> countingStates.contains(q.getState())).count() != expectedCount) { + MILLISECONDS.sleep(500); + } + } + + public static String getDbConfigUrl() + { + return "jdbc:h2:mem:test_" + Math.abs(new Random().nextLong()); + } + + public static H2ResourceGroupsDao getDao(String url) + { + DbResourceGroupConfig dbResourceGroupConfig = new DbResourceGroupConfig() + .setConfigDbUrl(url); + H2ResourceGroupsDao dao = new H2DaoProvider(dbResourceGroupConfig).get(); + dao.createResourceGroupsTable(); + dao.createSelectorsTable(); + dao.createResourceGroupsGlobalPropertiesTable(); + return dao; + } + + public static DistributedQueryRunner createQueryRunner(String dbConfigUrl, H2ResourceGroupsDao dao) + throws Exception + { + DistributedQueryRunner queryRunner = new DistributedQueryRunner( + testSessionBuilder().setCatalog("tpch").setSchema("tiny").build(), + 2, + ImmutableMap.of("experimental.resource-groups-enabled", "true"), + ImmutableMap.of(), + new SqlParserOptions()); + try { + Plugin h2ResourceGroupManagerPlugin = new H2ResourceGroupManagerPlugin(); + queryRunner.installPlugin(h2ResourceGroupManagerPlugin); + queryRunner.getCoordinator().getResourceGroupManager().get() + .setConfigurationManager(CONFIGURATION_MANAGER_TYPE, ImmutableMap.of("resource-groups.config-db-url", dbConfigUrl)); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + setup(queryRunner, dao); + return queryRunner; + } + catch (Exception e) { + queryRunner.close(); + throw e; + } + } + + public static DistributedQueryRunner getSimpleQueryRunner() + throws Exception + { + String dbConfigUrl = getDbConfigUrl(); + H2ResourceGroupsDao dao = getDao(dbConfigUrl); + return createQueryRunner(dbConfigUrl, dao); + } + + private static void setup(DistributedQueryRunner queryRunner, H2ResourceGroupsDao dao) + throws InterruptedException + { + dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); + dao.insertResourceGroup(1, "global", "1MB", 100, 1000, null, null, null, null, null, null, null, null); + dao.insertResourceGroup(2, "bi-${USER}", "1MB", 3, 2, null, null, null, null, null, null, null, 1L); + dao.insertResourceGroup(3, "user-${USER}", "1MB", 3, 3, null, null, null, null, null, null, null, 1L); + dao.insertResourceGroup(4, "adhoc-${USER}", "1MB", 3, 3, null, null, null, null, null, null, null, 3L); + dao.insertResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, null, 3L); + dao.insertSelector(2, "user.*", "test"); + dao.insertSelector(4, "user.*", "(?i).*adhoc.*"); + dao.insertSelector(5, "user.*", "(?i).*dashboard.*"); + // Selectors are loaded last + while (getSelectors(queryRunner).size() != 3) { + MILLISECONDS.sleep(500); + } + } + + public static List getSelectors(DistributedQueryRunner queryRunner) + { + return queryRunner.getCoordinator().getResourceGroupManager().get().getConfigurationManager().getSelectors(); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java index db8edf1a7d193..4226bf8eea81d 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestQueues.java @@ -13,46 +13,44 @@ */ package com.facebook.presto.execution.resourceGroups.db; -import com.facebook.presto.Session; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.execution.QueryState; -import com.facebook.presto.execution.TestingSessionFactory; import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; -import com.facebook.presto.resourceGroups.db.DbResourceGroupConfig; -import com.facebook.presto.resourceGroups.db.H2DaoProvider; import com.facebook.presto.resourceGroups.db.H2ResourceGroupsDao; -import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; -import com.facebook.presto.spi.resourceGroups.ResourceGroupSelector; -import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.tests.DistributedQueryRunner; -import com.facebook.presto.tests.tpch.TpchQueryRunner; -import com.facebook.presto.tpch.TpchPlugin; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.Set; import java.util.concurrent.TimeUnit; import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; -import static com.facebook.presto.execution.QueryState.TERMINAL_QUERY_STATES; +import static com.facebook.presto.execution.TestQueryRunnerUtil.cancelQuery; +import static com.facebook.presto.execution.TestQueryRunnerUtil.createQuery; +import static com.facebook.presto.execution.TestQueryRunnerUtil.waitForQueryState; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.adhocSession; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.createQueryRunner; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.dashboardSession; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getDao; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getDbConfigUrl; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getSelectors; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getSimpleQueryRunner; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.rejectingSession; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.waitForCompleteQueryCount; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.waitForRunningQueryCount; import static com.facebook.presto.spi.StandardErrorCode.QUERY_REJECTED; -import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.testng.Assert.assertEquals; +// run single threaded to avoid creating multiple query runners at once +@Test(singleThreaded = true) public class TestQueues { // Copy of TestQueues with tests for db reconfiguration of resource groups - private static final String NAME = "h2"; private static final String LONG_LASTING_QUERY = "SELECT COUNT(*) FROM lineitem"; @Test(timeOut = 60_000) @@ -71,7 +69,7 @@ public void testRunningQuery() } } - @Test(timeOut = 240_000) + @Test(timeOut = 60_000) public void testBasic() throws Exception { @@ -80,31 +78,31 @@ public void testBasic() try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); // submit first "dashboard" query - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); // wait for the first "dashboard" query to start waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); waitForRunningQueryCount(queryRunner, 1); // submit second "dashboard" query - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); MILLISECONDS.sleep(2000); // wait for the second "dashboard" query to be queued ("dashboard.${USER}" queue strategy only allows one "dashboard" query to be accepted for execution) waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); waitForRunningQueryCount(queryRunner, 1); // Update db to allow for 1 more running query in dashboard resource group - dao.updateResourceGroup(3, "user-${USER}", "1MB", 3, 4, null, null, null, null, null, 1L); - dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 2, null, null, null, null, null, 3L); + dao.updateResourceGroup(3, "user-${USER}", "1MB", 3, 4, null, null, null, null, null, null, null, 1L); + dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 2, null, null, null, null, null, null, null, 3L); waitForQueryState(queryRunner, secondDashboardQuery, RUNNING); - QueryId thirdDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId thirdDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, thirdDashboardQuery, QUEUED); waitForRunningQueryCount(queryRunner, 2); // submit first non "dashboard" query - QueryId firstNonDashboardQuery = createQuery(queryRunner, newSession(), LONG_LASTING_QUERY); + QueryId firstNonDashboardQuery = createQuery(queryRunner, adhocSession(), LONG_LASTING_QUERY); // wait for the first non "dashboard" query to start waitForQueryState(queryRunner, firstNonDashboardQuery, RUNNING); waitForRunningQueryCount(queryRunner, 3); // submit second non "dashboard" query - QueryId secondNonDashboardQuery = createQuery(queryRunner, newSession(), LONG_LASTING_QUERY); + QueryId secondNonDashboardQuery = createQuery(queryRunner, adhocSession(), LONG_LASTING_QUERY); // wait for the second non "dashboard" query to start waitForQueryState(queryRunner, secondNonDashboardQuery, RUNNING); waitForRunningQueryCount(queryRunner, 4); @@ -117,15 +115,15 @@ public void testBasic() } } - @Test(timeOut = 240_000) + @Test(timeOut = 60_000) public void testTwoQueriesAtSameTime() throws Exception { String dbConfigUrl = getDbConfigUrl(); H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); ImmutableSet queuedOrRunning = ImmutableSet.of(QUEUED, RUNNING); waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); @@ -133,31 +131,31 @@ public void testTwoQueriesAtSameTime() } } - @Test(timeOut = 240_000) + @Test(timeOut = 60_000) public void testTooManyQueries() throws Exception { String dbConfigUrl = getDbConfigUrl(); H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { - QueryId firstDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); - QueryId secondDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); - QueryId thirdDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + QueryId thirdDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, thirdDashboardQuery, FAILED); // Allow one more query to run and resubmit third query - dao.updateResourceGroup(3, "user-${USER}", "1MB", 3, 4, null, null, null, null, null, 1L); - dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 2, null, null, null, null, null, 3L); + dao.updateResourceGroup(3, "user-${USER}", "1MB", 3, 4, null, null, null, null, null, null, null, 1L); + dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 2, null, null, null, null, null, null, null, 3L); waitForQueryState(queryRunner, secondDashboardQuery, RUNNING); - thirdDashboardQuery = createQuery(queryRunner, newDashboardSession(), LONG_LASTING_QUERY); + thirdDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, thirdDashboardQuery, QUEUED); // Lower running queries in dashboard resource groups and wait until groups are reconfigured - dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, 3L); + dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, null, 3L); ResourceGroupManager manager = queryRunner.getCoordinator().getResourceGroupManager().get(); while (manager.getResourceGroupInfo( new ResourceGroupId(new ResourceGroupId(new ResourceGroupId("global"), "user-user"), "dashboard-user")).getMaxRunningQueries() != 1) { @@ -171,7 +169,7 @@ public void testTooManyQueries() } } - @Test(timeOut = 240_000) + @Test(timeOut = 60_000) public void testRejection() throws Exception { @@ -179,7 +177,7 @@ public void testRejection() H2ResourceGroupsDao dao = getDao(dbConfigUrl); try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { // Verify the query cannot be submitted - QueryId queryId = createQuery(queryRunner, newRejectionSession(), LONG_LASTING_QUERY); + QueryId queryId = createQuery(queryRunner, rejectingSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, queryId, FAILED); QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); assertEquals(queryManager.getQueryInfo(queryId).getErrorCode(), QUERY_REJECTED.toErrorCode()); @@ -190,168 +188,44 @@ public void testRejection() MILLISECONDS.sleep(500); } // Verify the query can be submitted - queryId = createQuery(queryRunner, newRejectionSession(), LONG_LASTING_QUERY); + queryId = createQuery(queryRunner, rejectingSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, queryId, RUNNING); dao.deleteSelector(4, "user.*", "(?i).*reject.*"); while (getSelectors(queryRunner).size() != selectorCount) { MILLISECONDS.sleep(500); } // Verify the query cannot be submitted - queryId = createQuery(queryRunner, newRejectionSession(), LONG_LASTING_QUERY); + queryId = createQuery(queryRunner, rejectingSession(), LONG_LASTING_QUERY); waitForQueryState(queryRunner, queryId, FAILED); } } - private static Session newSession() - { - return testSessionBuilder() - .setCatalog("tpch") - .setSchema("sf100000") - .setSource("adhoc") - .build(); - } - - private static Session newDashboardSession() - { - return testSessionBuilder() - .setCatalog("tpch") - .setSchema("sf100000") - .setSource("dashboard") - .build(); - } - - private static Session newRejectionSession() - { - return testSessionBuilder() - .setCatalog("tpch") - .setSchema("sf100000") - .setSource("reject") - .build(); - } - - private static QueryId createQuery(DistributedQueryRunner queryRunner, Session session, String sql) - { - return queryRunner.getCoordinator().getQueryManager().createQuery(new TestingSessionFactory(session), sql).getQueryId(); - } - - private static void cancelQuery(DistributedQueryRunner queryRunner, QueryId queryId) - { - queryRunner.getCoordinator().getQueryManager().cancelQuery(queryId); - } - - private static void waitForCompleteQueryCount(DistributedQueryRunner queryRunner, int expectedCount) - throws InterruptedException - { - waitForQueryCount(queryRunner, TERMINAL_QUERY_STATES, expectedCount); - } - - private static void waitForRunningQueryCount(DistributedQueryRunner queryRunner, int expectedCount) - throws InterruptedException - { - waitForQueryCount(queryRunner, ImmutableSet.of(RUNNING), expectedCount); - } - - private static void waitForQueryCount(DistributedQueryRunner queryRunner, Set countingStates, int expectedCount) - throws InterruptedException - { - QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); - while (queryManager.getAllQueryInfo().stream().filter(q -> countingStates.contains(q.getState())).count() != expectedCount) { - MILLISECONDS.sleep(500); - } - } - - private static void waitForQueryState(DistributedQueryRunner queryRunner, QueryId queryId, QueryState expectedQueryState) - throws InterruptedException - { - waitForQueryState(queryRunner, queryId, ImmutableSet.of(expectedQueryState)); - } - - private static void waitForQueryState(DistributedQueryRunner queryRunner, QueryId queryId, Set expectedQueryStates) - throws InterruptedException - { - QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); - while (!expectedQueryStates.contains(queryManager.getQueryInfo(queryId).getState())) { - MILLISECONDS.sleep(500); - } - } - - private static String getDbConfigUrl() - { - Random rnd = new Random(); - return "jdbc:h2:mem:test_" + Math.abs(rnd.nextLong()); - } - - private static H2ResourceGroupsDao getDao(String url) - { - DbResourceGroupConfig dbResourceGroupConfig = new DbResourceGroupConfig() - .setConfigDbUrl(url); - H2ResourceGroupsDao dao = new H2DaoProvider(dbResourceGroupConfig).get(); - dao.createResourceGroupsTable(); - dao.createSelectorsTable(); - dao.createResourceGroupsGlobalPropertiesTable(); - return dao; - } - - private static DistributedQueryRunner createQueryRunner(String dbConfigUrl, H2ResourceGroupsDao dao) + @Test(timeOut = 60_000) + public void testRunningTimeLimit() throws Exception { - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("experimental.resource-groups-enabled", "true"); - Map properties = builder.build(); - DistributedQueryRunner queryRunner = new DistributedQueryRunner(testSessionBuilder().build(), 2, ImmutableMap.of(), properties, new SqlParserOptions()); - try { - Plugin h2ResourceGroupManagerPlugin = new H2ResourceGroupManagerPlugin(); - queryRunner.installPlugin(h2ResourceGroupManagerPlugin); - queryRunner.getCoordinator().getResourceGroupManager().get() - .setConfigurationManager(NAME, ImmutableMap.of("resource-groups.config-db-url", dbConfigUrl)); - queryRunner.installPlugin(new TpchPlugin()); - queryRunner.createCatalog("tpch", "tpch"); - setup(queryRunner, dao); - return queryRunner; - } - catch (Exception e) { - queryRunner.close(); - throw e; + String dbConfigUrl = getDbConfigUrl(); + H2ResourceGroupsDao dao = getDao(dbConfigUrl); + try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { + dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, null, "3s", 3L); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); + waitForQueryState(queryRunner, firstDashboardQuery, FAILED); } } - static DistributedQueryRunner getSimpleQueryRunner() + @Test(timeOut = 60_000) + public void testQueuedTimeLimit() throws Exception { String dbConfigUrl = getDbConfigUrl(); H2ResourceGroupsDao dao = getDao(dbConfigUrl); - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("experimental.resource-groups-enabled", "true"); - Map properties = builder.build(); - DistributedQueryRunner queryRunner = TpchQueryRunner.createQueryRunner(properties); - Plugin h2ResourceGroupManagerPlugin = new H2ResourceGroupManagerPlugin(); - queryRunner.installPlugin(h2ResourceGroupManagerPlugin); - queryRunner.getCoordinator().getResourceGroupManager().get() - .setConfigurationManager(NAME, ImmutableMap.of("resource-groups.config-db-url", dbConfigUrl)); - setup(queryRunner, dao); - return queryRunner; - } - - private static void setup(DistributedQueryRunner queryRunner, H2ResourceGroupsDao dao) - throws InterruptedException - { - dao.insertResourceGroupsGlobalProperties("cpu_quota_period", "1h"); - dao.insertResourceGroup(1, "global", "1MB", 100, 1000, null, null, null, null, null, null); - dao.insertResourceGroup(2, "bi-${USER}", "1MB", 3, 2, null, null, null, null, null, 1L); - dao.insertResourceGroup(3, "user-${USER}", "1MB", 3, 3, null, null, null, null, null, 1L); - dao.insertResourceGroup(4, "adhoc-${USER}", "1MB", 3, 3, null, null, null, null, null, 3L); - dao.insertResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, 3L); - dao.insertSelector(2, "user.*", "test"); - dao.insertSelector(4, "user.*", "(?i).*adhoc.*"); - dao.insertSelector(5, "user.*", "(?i).*dashboard.*"); - // Selectors are loaded last - while (getSelectors(queryRunner).size() != 3) { - MILLISECONDS.sleep(500); + try (DistributedQueryRunner queryRunner = createQueryRunner(dbConfigUrl, dao)) { + dao.updateResourceGroup(5, "dashboard-${USER}", "1MB", 1, 1, null, null, null, null, null, "5s", null, 3L); + QueryId firstDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); + waitForQueryState(queryRunner, firstDashboardQuery, RUNNING); + QueryId secondDashboardQuery = createQuery(queryRunner, dashboardSession(), LONG_LASTING_QUERY); + waitForQueryState(queryRunner, secondDashboardQuery, QUEUED); + waitForQueryState(queryRunner, secondDashboardQuery, FAILED); } } - - private static List getSelectors(DistributedQueryRunner queryRunner) - { - return queryRunner.getCoordinator().getResourceGroupManager().get().getConfigurationManager().getSelectors(); - } } diff --git a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java index 37bd72d50224a..be44580e89123 100644 --- a/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java +++ b/presto-tests/src/test/java/com/facebook/presto/execution/resourceGroups/db/TestResourceGroupIntegration.java @@ -13,30 +13,21 @@ */ package com.facebook.presto.execution.resourceGroups.db; -import com.facebook.presto.spi.resourceGroups.ResourceGroupId; -import com.facebook.presto.spi.resourceGroups.ResourceGroupInfo; import com.facebook.presto.tests.DistributedQueryRunner; import org.testng.annotations.Test; -import java.util.concurrent.TimeUnit; - -import static com.facebook.presto.execution.resourceGroups.db.TestQueues.getSimpleQueryRunner; +import static com.facebook.presto.execution.resourceGroups.TestResourceGroupIntegration.waitForGlobalResourceGroup; +import static com.facebook.presto.execution.resourceGroups.db.H2TestUtil.getSimpleQueryRunner; public class TestResourceGroupIntegration { - @Test(timeOut = 60_000) + @Test public void testMemoryFraction() throws Exception { try (DistributedQueryRunner queryRunner = getSimpleQueryRunner()) { queryRunner.execute("SELECT COUNT(*), clerk FROM orders GROUP BY clerk"); - while (true) { - TimeUnit.SECONDS.sleep(1); - ResourceGroupInfo global = queryRunner.getCoordinator().getResourceGroupManager().get().getResourceGroupInfo(new ResourceGroupId("global")); - if (global.getSoftMemoryLimit().toBytes() > 0) { - break; - } - } + waitForGlobalResourceGroup(queryRunner); } } } diff --git a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java index a968fae22755b..b63dfce3044a0 100644 --- a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java +++ b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java @@ -52,6 +52,7 @@ import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; +// run single threaded to avoid creating multiple query runners at once @Test(singleThreaded = true) public class TestMemoryManager { @@ -123,6 +124,7 @@ public void testOutOfMemoryKiller() while (!queryDone) { for (QueryInfo info : queryRunner.getCoordinator().getQueryManager().getAllQueryInfo()) { if (info.getState().isDone()) { + assertNotNull(info.getErrorCode()); assertEquals(info.getErrorCode().getCode(), CLUSTER_OUT_OF_MEMORY.toErrorCode().getCode()); queryDone = true; break; diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java index 743fb12ba69f6..b5c5964e1f3dc 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java @@ -17,12 +17,18 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.tpch.TpchConnectorFactory; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; -import static com.facebook.presto.SystemSessionProperties.REORDER_JOINS; +import static com.facebook.presto.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingSession.TESTING_CATALOG; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; public class TestLocalQueries @@ -38,7 +44,7 @@ public static LocalQueryRunner createLocalQueryRunner() Session defaultSession = testSessionBuilder() .setCatalog("local") .setSchema(TINY_SCHEMA_NAME) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true") .build(); LocalQueryRunner localQueryRunner = new LocalQueryRunner(defaultSession); @@ -58,4 +64,18 @@ public static LocalQueryRunner createLocalQueryRunner() return localQueryRunner; } + + @Test + public void testShowColumnStats() + throws Exception + { + // FIXME Add tests for more complex scenario with more stats + MaterializedResult result = computeActual("SHOW STATS FOR nation"); + + MaterializedResult expectedStatistics = resultBuilder(getSession(), VARCHAR, DOUBLE) + .row(null, 25.0) + .build(); + + assertEquals(result, expectedStatistics); + } } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestMetadataManager.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestMetadataManager.java index bf199dc27d467..82cfece57c4e7 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestMetadataManager.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestMetadataManager.java @@ -13,11 +13,17 @@ */ package com.facebook.presto.tests; +import com.facebook.presto.execution.QueryManager; +import com.facebook.presto.execution.TestingSessionFactory; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.spi.QueryId; import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.execution.QueryState.RUNNING; import static com.facebook.presto.tests.tpch.TpchQueryRunner.createQueryRunner; import static org.testng.Assert.assertEquals; @@ -30,16 +36,23 @@ @Test(singleThreaded = true) public class TestMetadataManager { - private final QueryRunner queryRunner; - private final MetadataManager metadataManager; + private DistributedQueryRunner queryRunner; + private MetadataManager metadataManager; - TestMetadataManager() + @BeforeClass + public void setUp() throws Exception { queryRunner = createQueryRunner(); metadataManager = (MetadataManager) queryRunner.getMetadata(); } + @AfterClass(alwaysRun = true) + public void tearDown() + { + queryRunner.close(); + } + @Test public void testMetadataIsClearedAfterQueryFinished() { @@ -62,4 +75,25 @@ public void testMetadataIsClearedAfterQueryFailed() assertEquals(metadataManager.getCatalogsByQueryId().size(), 0); } + + @Test + public void testMetadataIsClearedAfterQueryCanceled() + throws Exception + { + QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); + QueryId queryId = queryManager.createQuery(new TestingSessionFactory(TEST_SESSION), + "SELECT * FROM lineitem").getQueryId(); + + // wait until query starts running + while (true) { + if (queryManager.getQueryInfo(queryId).getState() == RUNNING) { + break; + } + Thread.sleep(100); + } + + // cancel query + queryManager.cancelQuery(queryId); + assertEquals(metadataManager.getCatalogsByQueryId().size(), 0); + } } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestQuerySpillLimits.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestQuerySpillLimits.java index 158cf2283dfc0..a65f167b9fe94 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestQuerySpillLimits.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestQuerySpillLimits.java @@ -28,6 +28,7 @@ import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +@Test(singleThreaded = true) public class TestQuerySpillLimits { private static final Session SESSION = testSessionBuilder() @@ -44,7 +45,7 @@ public void setUp() this.spillPath = Files.createTempDir(); } - @AfterMethod + @AfterMethod(alwaysRun = true) public void tearDown() throws Exception { diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java new file mode 100644 index 0000000000000..de6b72355c556 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestTpchDistributedStats.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.execution.StageInfo; +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.tests.statistics.Metric; +import com.facebook.presto.tests.statistics.MetricComparator; +import com.facebook.presto.tests.statistics.MetricComparison; +import com.facebook.presto.tpch.ColumnNaming; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.DIFFER; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.MATCH; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_BASELINE; +import static com.facebook.presto.tests.statistics.MetricComparison.Result.NO_ESTIMATE; +import static com.facebook.presto.tests.tpch.TpchQueryRunner.createQueryRunnerWithoutCatalogs; +import static java.lang.String.format; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.stream.Collectors.groupingBy; +import static org.testng.Assert.assertEquals; + +public class TestTpchDistributedStats +{ + public static final int NUMBER_OF_TPCH_QUERIES = 22; + + DistributedQueryRunner runner; + + public TestTpchDistributedStats() + throws Exception + { + runner = createQueryRunnerWithoutCatalogs(emptyMap(), emptyMap()); + runner.createCatalog("tpch", "tpch", ImmutableMap.of( + "tpch.column-naming", ColumnNaming.STANDARD.name() + )); + } + + @Test + void testEstimateForSimpleQuery() + { + String queryId = executeQuery("SELECT * FROM NATION"); + + Plan queryPlan = getQueryPlan(queryId); + + MetricComparison rootOutputRowCountComparison = getRootOutputRowCountComparison(queryId, queryPlan); + assertEquals(rootOutputRowCountComparison.result(), MATCH); + } + + private MetricComparison getRootOutputRowCountComparison(String queryId, Plan queryPlan) + { + List comparisons = new MetricComparator().getMetricComparisons(queryPlan, getOutputStageInfo(queryId)); + return comparisons.stream() + .filter(comparison -> comparison.getMetric().equals(Metric.OUTPUT_ROW_COUNT)) + .filter(comparison -> comparison.getPlanNode().equals(queryPlan.getRoot())) + .findFirst() + .orElseThrow(() -> new AssertionError("No comparison for root node found")); + } + + /** + * This is a development tool for manual inspection of differences between + * cost estimates and actual execution costs. Its outputs need to be inspected + * manually because at this point no sensible assertions can be formulated + * for the entirety of TPCH queries. + */ + @Test(enabled = false) + void testCostEstimatesVsRealityDifferences() + { + IntStream.rangeClosed(1, NUMBER_OF_TPCH_QUERIES) + .filter(i -> i != 15) //query 15 creates a view, which TPCH connector does not support. + .forEach(i -> summarizeQuery(i, getTpchQuery(i))); + } + + private String getTpchQuery(int i) + { + try { + String queryClassPath = "/io/airlift/tpch/queries/q" + i + ".sql"; + return Resources.toString(getClass().getResource(queryClassPath), Charset.defaultCharset()); + } + catch (IOException e) { + throw Throwables.propagate(e); + } + } + + private Plan getQueryPlan(String queryId) + { + return runner.getQueryPlan(new QueryId(queryId)); + } + + private void summarizeQuery(int queryNumber, String query) + { + String queryId = executeQuery(query); + Plan queryPlan = getQueryPlan(queryId); + + List allPlanNodes = searchFrom(queryPlan.getRoot()).findAll(); + + System.out.println(format("Query TPCH [%s] produces [%s] plan nodes.\n", queryNumber, allPlanNodes.size())); + + List comparisons = new MetricComparator().getMetricComparisons(queryPlan, getOutputStageInfo(queryId)); + + Map>> metricSummaries = + comparisons.stream() + .collect(groupingBy(MetricComparison::getMetric, groupingBy(MetricComparison::result))); + + metricSummaries.forEach((metricName, resultSummaries) -> { + System.out.println(format("Summary for metric [%s]", metricName)); + outputSummary(resultSummaries, NO_ESTIMATE); + outputSummary(resultSummaries, NO_BASELINE); + outputSummary(resultSummaries, DIFFER); + outputSummary(resultSummaries, MATCH); + System.out.println(); + }); + + System.out.println("Detailed results:\n"); + + comparisons.forEach(System.out::println); + } + + private String executeQuery(String query) + { + return runner.executeWithQueryId(runner.getDefaultSession(), query).getQueryId(); + } + + private StageInfo getOutputStageInfo(String queryId) + { + return runner.getQueryInfo(new QueryId(queryId)).getOutputStage().get(); + } + + private void outputSummary(Map> resultSummaries, MetricComparison.Result result) + { + System.out.println(format("[%s]\t-\t[%s]", result, resultSummaries.getOrDefault(result, emptyList()).size())); + } +} diff --git a/presto-tests/src/test/resources/resource_groups_config_simple.json b/presto-tests/src/test/resources/resource_groups_config_simple.json new file mode 100644 index 0000000000000..6a1a65dbc55f2 --- /dev/null +++ b/presto-tests/src/test/resources/resource_groups_config_simple.json @@ -0,0 +1,23 @@ +{ + "rootGroups": [ + { + "name": "global", + "softMemoryLimit": "1MB", + "maxRunning": 100, + "maxQueued": 1000, + "subGroups": [ + { + "name": "user-${USER}", + "softMemoryLimit": "1MB", + "maxRunning": 3, + "maxQueued": 3 + } + ] + } + ], + "selectors": [ + { + "group": "global.user-${USER}" + } + ] +} diff --git a/presto-thrift-connector-api/pom.xml b/presto-thrift-connector-api/pom.xml new file mode 100644 index 0000000000000..f120ab42a0028 --- /dev/null +++ b/presto-thrift-connector-api/pom.xml @@ -0,0 +1,84 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.180-SNAPSHOT + + + presto-thrift-connector-api + Presto - Thrift Connector API + jar + + + ${project.parent.basedir} + + + + + com.google.guava + guava + + + + com.google.code.findbugs + annotations + + + + com.facebook.swift + swift-annotations + + + + com.facebook.presto + presto-spi + + + + io.airlift + slice + + + + com.fasterxml.jackson.core + jackson-annotations + + + + + + + com.facebook.swift + swift-javadoc + provided + + + + com.facebook.swift + swift-codec + provided + + + + + org.testng + testng + test + + + + com.facebook.presto + presto-main + test + + + + io.airlift + stats + test + + + diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/NameValidationUtils.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/NameValidationUtils.java new file mode 100644 index 0000000000000..92c1abd3c7634 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/NameValidationUtils.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; + +final class NameValidationUtils +{ + private NameValidationUtils() {} + + public static String checkValidName(String name) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + checkArgument('a' <= name.charAt(0) && name.charAt(0) <= 'z', "name must start with a lowercase latin letter: '%s'", name); + for (int i = 1; i < name.length(); i++) { + char ch = name.charAt(i); + checkArgument('a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' || ch == '_', + "name must contain only lowercase latin letters, digits or underscores: '%s'", name); + } + return name; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftBlock.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftBlock.java new file mode 100644 index 0000000000000..33838b84fbfee --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftBlock.java @@ -0,0 +1,311 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigint; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigintArray; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBoolean; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftColumnData; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftDate; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftDouble; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftHyperLogLog; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftInteger; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftJson; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTimestamp; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftVarchar; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.BigintType; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.spi.type.StandardTypes.ARRAY; +import static com.facebook.presto.spi.type.StandardTypes.BIGINT; +import static com.facebook.presto.spi.type.StandardTypes.BOOLEAN; +import static com.facebook.presto.spi.type.StandardTypes.DATE; +import static com.facebook.presto.spi.type.StandardTypes.DOUBLE; +import static com.facebook.presto.spi.type.StandardTypes.HYPER_LOG_LOG; +import static com.facebook.presto.spi.type.StandardTypes.INTEGER; +import static com.facebook.presto.spi.type.StandardTypes.JSON; +import static com.facebook.presto.spi.type.StandardTypes.TIMESTAMP; +import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterables.getOnlyElement; + +@ThriftStruct +public final class PrestoThriftBlock +{ + // number + private final PrestoThriftInteger integerData; + private final PrestoThriftBigint bigintData; + private final PrestoThriftDouble doubleData; + + // variable width + private final PrestoThriftVarchar varcharData; + + // boolean + private final PrestoThriftBoolean booleanData; + + // temporal + private final PrestoThriftDate dateData; + private final PrestoThriftTimestamp timestampData; + + // special + private final PrestoThriftJson jsonData; + private final PrestoThriftHyperLogLog hyperLogLogData; + + // array + private final PrestoThriftBigintArray bigintArrayData; + + // non-thrift field which points to non-null data item + private final PrestoThriftColumnData dataReference; + + @ThriftConstructor + public PrestoThriftBlock( + @Nullable PrestoThriftInteger integerData, + @Nullable PrestoThriftBigint bigintData, + @Nullable PrestoThriftDouble doubleData, + @Nullable PrestoThriftVarchar varcharData, + @Nullable PrestoThriftBoolean booleanData, + @Nullable PrestoThriftDate dateData, + @Nullable PrestoThriftTimestamp timestampData, + @Nullable PrestoThriftJson jsonData, + @Nullable PrestoThriftHyperLogLog hyperLogLogData, + @Nullable PrestoThriftBigintArray bigintArrayData) + { + this.integerData = integerData; + this.bigintData = bigintData; + this.doubleData = doubleData; + this.varcharData = varcharData; + this.booleanData = booleanData; + this.dateData = dateData; + this.timestampData = timestampData; + this.jsonData = jsonData; + this.hyperLogLogData = hyperLogLogData; + this.bigintArrayData = bigintArrayData; + this.dataReference = theOnlyNonNull(integerData, bigintData, doubleData, varcharData, booleanData, dateData, timestampData, jsonData, hyperLogLogData, bigintArrayData); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftInteger getIntegerData() + { + return integerData; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public PrestoThriftBigint getBigintData() + { + return bigintData; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftDouble getDoubleData() + { + return doubleData; + } + + @Nullable + @ThriftField(value = 4, requiredness = OPTIONAL) + public PrestoThriftVarchar getVarcharData() + { + return varcharData; + } + + @Nullable + @ThriftField(value = 5, requiredness = OPTIONAL) + public PrestoThriftBoolean getBooleanData() + { + return booleanData; + } + + @Nullable + @ThriftField(value = 6, requiredness = OPTIONAL) + public PrestoThriftDate getDateData() + { + return dateData; + } + + @Nullable + @ThriftField(value = 7, requiredness = OPTIONAL) + public PrestoThriftTimestamp getTimestampData() + { + return timestampData; + } + + @Nullable + @ThriftField(value = 8, requiredness = OPTIONAL) + public PrestoThriftJson getJsonData() + { + return jsonData; + } + + @Nullable + @ThriftField(value = 9, requiredness = OPTIONAL) + public PrestoThriftHyperLogLog getHyperLogLogData() + { + return hyperLogLogData; + } + + @Nullable + @ThriftField(value = 10, requiredness = OPTIONAL) + public PrestoThriftBigintArray getBigintArrayData() + { + return bigintArrayData; + } + + public Block toBlock(Type desiredType) + { + return dataReference.toBlock(desiredType); + } + + public int numberOfRecords() + { + return dataReference.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBlock other = (PrestoThriftBlock) obj; + // remaining fields are guaranteed to be null by the constructor + return Objects.equals(this.dataReference, other.dataReference); + } + + @Override + public int hashCode() + { + return Objects.hash(integerData, bigintData, doubleData, varcharData, booleanData, dateData, timestampData, jsonData, hyperLogLogData, bigintArrayData); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("data", dataReference) + .toString(); + } + + public static PrestoThriftBlock integerData(PrestoThriftInteger integerData) + { + return new PrestoThriftBlock(integerData, null, null, null, null, null, null, null, null, null); + } + + public static PrestoThriftBlock bigintData(PrestoThriftBigint bigintData) + { + return new PrestoThriftBlock(null, bigintData, null, null, null, null, null, null, null, null); + } + + public static PrestoThriftBlock doubleData(PrestoThriftDouble doubleData) + { + return new PrestoThriftBlock(null, null, doubleData, null, null, null, null, null, null, null); + } + + public static PrestoThriftBlock varcharData(PrestoThriftVarchar varcharData) + { + return new PrestoThriftBlock(null, null, null, varcharData, null, null, null, null, null, null); + } + + public static PrestoThriftBlock booleanData(PrestoThriftBoolean booleanData) + { + return new PrestoThriftBlock(null, null, null, null, booleanData, null, null, null, null, null); + } + + public static PrestoThriftBlock dateData(PrestoThriftDate dateData) + { + return new PrestoThriftBlock(null, null, null, null, null, dateData, null, null, null, null); + } + + public static PrestoThriftBlock timestampData(PrestoThriftTimestamp timestampData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, timestampData, null, null, null); + } + + public static PrestoThriftBlock jsonData(PrestoThriftJson jsonData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, null, jsonData, null, null); + } + + public static PrestoThriftBlock hyperLogLogData(PrestoThriftHyperLogLog hyperLogLogData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, null, null, hyperLogLogData, null); + } + + public static PrestoThriftBlock bigintArrayData(PrestoThriftBigintArray bigintArrayData) + { + return new PrestoThriftBlock(null, null, null, null, null, null, null, null, null, bigintArrayData); + } + + public static PrestoThriftBlock fromBlock(Block block, Type type) + { + switch (type.getTypeSignature().getBase()) { + case INTEGER: + return PrestoThriftInteger.fromBlock(block); + case BIGINT: + return PrestoThriftBigint.fromBlock(block); + case DOUBLE: + return PrestoThriftDouble.fromBlock(block); + case VARCHAR: + return PrestoThriftVarchar.fromBlock(block, type); + case BOOLEAN: + return PrestoThriftBoolean.fromBlock(block); + case DATE: + return PrestoThriftDate.fromBlock(block); + case TIMESTAMP: + return PrestoThriftTimestamp.fromBlock(block); + case JSON: + return PrestoThriftJson.fromBlock(block, type); + case HYPER_LOG_LOG: + return PrestoThriftHyperLogLog.fromBlock(block); + case ARRAY: + Type elementType = getOnlyElement(type.getTypeParameters()); + if (BigintType.BIGINT.equals(elementType)) { + return PrestoThriftBigintArray.fromBlock(block); + } + else { + throw new IllegalArgumentException("Unsupported array block type: " + type); + } + default: + throw new IllegalArgumentException("Unsupported block type: " + type); + } + } + + private static PrestoThriftColumnData theOnlyNonNull(PrestoThriftColumnData... columnsData) + { + PrestoThriftColumnData result = null; + for (PrestoThriftColumnData data : columnsData) { + if (data != null) { + checkArgument(result == null, "more than one type is present"); + result = data; + } + } + checkArgument(result != null, "no types are present"); + return result; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftColumnMetadata.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftColumnMetadata.java new file mode 100644 index 0000000000000..16bba53e46162 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftColumnMetadata.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftColumnMetadata +{ + private final String name; + private final String type; + private final String comment; + private final boolean hidden; + + @ThriftConstructor + public PrestoThriftColumnMetadata(String name, String type, @Nullable String comment, boolean hidden) + { + this.name = checkValidName(name); + this.type = requireNonNull(type, "type is null"); + this.comment = comment; + this.hidden = hidden; + } + + @ThriftField(1) + public String getName() + { + return name; + } + + @ThriftField(2) + public String getType() + { + return type; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public String getComment() + { + return comment; + } + + @ThriftField(4) + public boolean isHidden() + { + return hidden; + } + + public ColumnMetadata toColumnMetadata(TypeManager typeManager) + { + return new ColumnMetadata( + name, + typeManager.getType(parseTypeSignature(type)), + comment, + hidden); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftColumnMetadata other = (PrestoThriftColumnMetadata) obj; + return Objects.equals(this.name, other.name) && + Objects.equals(this.type, other.type) && + Objects.equals(this.comment, other.comment) && + this.hidden == other.hidden; + } + + @Override + public int hashCode() + { + return Objects.hash(name, type, comment, hidden); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("type", type) + .add("comment", comment) + .add("hidden", hidden) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftDomain.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftDomain.java new file mode 100644 index 0000000000000..a4eb6eb6fc5f3 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftDomain.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftDomain +{ + private final PrestoThriftValueSet valueSet; + private final boolean nullAllowed; + + @ThriftConstructor + public PrestoThriftDomain(PrestoThriftValueSet valueSet, boolean nullAllowed) + { + this.valueSet = requireNonNull(valueSet, "valueSet is null"); + this.nullAllowed = nullAllowed; + } + + @ThriftField(1) + public PrestoThriftValueSet getValueSet() + { + return valueSet; + } + + @ThriftField(2) + public boolean isNullAllowed() + { + return nullAllowed; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftDomain other = (PrestoThriftDomain) obj; + return Objects.equals(this.valueSet, other.valueSet) && + this.nullAllowed == other.nullAllowed; + } + + @Override + public int hashCode() + { + return Objects.hash(valueSet, nullAllowed); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("valueSet", valueSet) + .add("nullAllowed", nullAllowed) + .toString(); + } + + public static PrestoThriftDomain fromDomain(Domain domain) + { + return new PrestoThriftDomain(fromValueSet(domain.getValues()), domain.isNullAllowed()); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftHostAddress.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftHostAddress.java new file mode 100644 index 0000000000000..d3317fdf47184 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftHostAddress.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftHostAddress +{ + private final String host; + private final int port; + + @ThriftConstructor + public PrestoThriftHostAddress(String host, int port) + { + this.host = requireNonNull(host, "host is null"); + this.port = port; + } + + @ThriftField(1) + public String getHost() + { + return host; + } + + @ThriftField(2) + public int getPort() + { + return port; + } + + public HostAddress toHostAddress() + { + return HostAddress.fromParts(getHost(), getPort()); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftHostAddress other = (PrestoThriftHostAddress) obj; + return Objects.equals(this.host, other.host) && + this.port == other.port; + } + + @Override + public int hashCode() + { + return Objects.hash(host, port); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("host", host) + .add("port", port) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftId.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftId.java new file mode 100644 index 0000000000000..082c960fb8f84 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftId.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.BaseEncoding; + +import java.util.Arrays; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftId +{ + private static final int PREFIX_SUFFIX_BYTES = 8; + private static final String FILLER = ".."; + private static final int MAX_DISPLAY_CHARACTERS = PREFIX_SUFFIX_BYTES * 4 + FILLER.length(); + + private final byte[] id; + + @JsonCreator + @ThriftConstructor + public PrestoThriftId(@JsonProperty("id") byte[] id) + { + this.id = requireNonNull(id, "id is null"); + } + + @JsonProperty + @ThriftField(1) + public byte[] getId() + { + return id; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftId other = (PrestoThriftId) obj; + return Arrays.equals(this.id, other.id); + } + + @Override + public int hashCode() + { + return Arrays.hashCode(id); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("id", summarize(id)) + .toString(); + } + + @VisibleForTesting + static String summarize(byte[] value) + { + if (value.length * 2 <= MAX_DISPLAY_CHARACTERS) { + return BaseEncoding.base16().encode(value); + } + return BaseEncoding.base16().encode(value, 0, PREFIX_SUFFIX_BYTES) + + FILLER + + BaseEncoding.base16().encode(value, value.length - PREFIX_SUFFIX_BYTES, PREFIX_SUFFIX_BYTES); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableColumnSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableColumnSet.java new file mode 100644 index 0000000000000..41b1e7f8b6d91 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableColumnSet.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; +import java.util.Set; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableColumnSet +{ + private final Set columns; + + @ThriftConstructor + public PrestoThriftNullableColumnSet(@Nullable Set columns) + { + this.columns = columns; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public Set getColumns() + { + return columns; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableColumnSet other = (PrestoThriftNullableColumnSet) obj; + return Objects.equals(this.columns, other.columns); + } + + @Override + public int hashCode() + { + return Objects.hashCode(columns); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columns", columns) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableSchemaName.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableSchemaName.java new file mode 100644 index 0000000000000..f048a9e92beae --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableSchemaName.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableSchemaName +{ + private final String schemaName; + + @ThriftConstructor + public PrestoThriftNullableSchemaName(@Nullable String schemaName) + { + this.schemaName = schemaName; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public String getSchemaName() + { + return schemaName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableSchemaName other = (PrestoThriftNullableSchemaName) obj; + return Objects.equals(this.schemaName, other.schemaName); + } + + @Override + public int hashCode() + { + return Objects.hashCode(schemaName); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", schemaName) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableTableMetadata.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableTableMetadata.java new file mode 100644 index 0000000000000..a94cfcf401b57 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableTableMetadata.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftNullableTableMetadata +{ + private final PrestoThriftTableMetadata tableMetadata; + + @ThriftConstructor + public PrestoThriftNullableTableMetadata(@Nullable PrestoThriftTableMetadata tableMetadata) + { + this.tableMetadata = tableMetadata; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftTableMetadata getTableMetadata() + { + return tableMetadata; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftNullableTableMetadata other = (PrestoThriftNullableTableMetadata) obj; + return Objects.equals(this.tableMetadata, other.tableMetadata); + } + + @Override + public int hashCode() + { + return Objects.hashCode(tableMetadata); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("tableMetadata", tableMetadata) + .toString(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingColumnHandle.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableToken.java similarity index 55% rename from presto-main/src/test/java/com/facebook/presto/sql/planner/TestingColumnHandle.java rename to presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableToken.java index fb50987227270..b8670e5cb1520 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestingColumnHandle.java +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftNullableToken.java @@ -11,38 +11,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql.planner; +package com.facebook.presto.connector.thrift.api; -import com.facebook.presto.spi.ColumnHandle; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; import java.util.Objects; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; import static com.google.common.base.MoreObjects.toStringHelper; -import static java.util.Objects.requireNonNull; -public class TestingColumnHandle - implements ColumnHandle +@ThriftStruct +public final class PrestoThriftNullableToken { - private final String name; - - @JsonCreator - public TestingColumnHandle(@JsonProperty("name") String name) - { - this.name = requireNonNull(name, "name is null"); - } + private final PrestoThriftId token; - @JsonProperty - public String getName() + @ThriftConstructor + public PrestoThriftNullableToken(@Nullable PrestoThriftId token) { - return name; + this.token = token; } - @Override - public int hashCode() + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftId getToken() { - return Objects.hash(name); + return token; } @Override @@ -54,15 +51,21 @@ public boolean equals(Object obj) if (obj == null || getClass() != obj.getClass()) { return false; } - final TestingColumnHandle other = (TestingColumnHandle) obj; - return Objects.equals(this.name, other.name); + PrestoThriftNullableToken other = (PrestoThriftNullableToken) obj; + return Objects.equals(this.token, other.token); + } + + @Override + public int hashCode() + { + return Objects.hashCode(token); } @Override public String toString() { return toStringHelper(this) - .add("name", name) + .add("token", token) .toString(); } } diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftPageResult.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftPageResult.java new file mode 100644 index 0000000000000..71d046d3c37a1 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftPageResult.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftPageResult +{ + private final List columnBlocks; + private final int rowCount; + private final PrestoThriftId nextToken; + + @ThriftConstructor + public PrestoThriftPageResult(List columnBlocks, int rowCount, @Nullable PrestoThriftId nextToken) + { + this.columnBlocks = requireNonNull(columnBlocks, "columnBlocks is null"); + checkArgument(rowCount >= 0, "rowCount is negative"); + checkAllColumnsAreOfExpectedSize(columnBlocks, rowCount); + this.rowCount = rowCount; + this.nextToken = nextToken; + } + + /** + * Returns data in a columnar format. + * Columns in this list must be in the order they were requested by the engine. + */ + @ThriftField(1) + public List getColumnBlocks() + { + return columnBlocks; + } + + @ThriftField(2) + public int getRowCount() + { + return rowCount; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftId getNextToken() + { + return nextToken; + } + + @Nullable + public Page toPage(List columnTypes) + { + if (rowCount == 0) { + return null; + } + checkArgument(columnBlocks.size() == columnTypes.size(), "columns and types have different sizes"); + int numberOfColumns = columnBlocks.size(); + if (numberOfColumns == 0) { + // request/response with no columns, used for queries like "select count star" + return new Page(rowCount); + } + Block[] blocks = new Block[numberOfColumns]; + for (int i = 0; i < numberOfColumns; i++) { + blocks[i] = columnBlocks.get(i).toBlock(columnTypes.get(i)); + } + return new Page(blocks); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftPageResult other = (PrestoThriftPageResult) obj; + return Objects.equals(this.columnBlocks, other.columnBlocks) && + this.rowCount == other.rowCount && + Objects.equals(this.nextToken, other.nextToken); + } + + @Override + public int hashCode() + { + return Objects.hash(columnBlocks, rowCount, nextToken); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columnBlocks", columnBlocks) + .add("rowCount", rowCount) + .add("nextToken", nextToken) + .toString(); + } + + private static void checkAllColumnsAreOfExpectedSize(List columnBlocks, int expectedNumberOfRows) + { + for (int i = 0; i < columnBlocks.size(); i++) { + checkArgument(columnBlocks.get(i).numberOfRecords() == expectedNumberOfRows, + "Incorrect number of records for column with index %s: expected %s, got %s", + i, expectedNumberOfRows, columnBlocks.get(i).numberOfRecords()); + } + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSchemaTableName.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSchemaTableName.java new file mode 100644 index 0000000000000..0efce302de411 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSchemaTableName.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftSchemaTableName +{ + private final String schemaName; + private final String tableName; + + @ThriftConstructor + public PrestoThriftSchemaTableName(String schemaName, String tableName) + { + this.schemaName = checkValidName(schemaName); + this.tableName = checkValidName(tableName); + } + + @ThriftField(1) + public String getSchemaName() + { + return schemaName; + } + + @ThriftField(2) + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftSchemaTableName other = (PrestoThriftSchemaTableName) obj; + return Objects.equals(this.schemaName, other.schemaName) && + Objects.equals(this.tableName, other.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaName, tableName); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", schemaName) + .add("tableName", tableName) + .toString(); + } + + public SchemaTableName toSchemaTableName() + { + return new SchemaTableName(getSchemaName(), getTableName()); + } + + public static PrestoThriftSchemaTableName fromSchemaTableName(SchemaTableName schemaTableName) + { + return new PrestoThriftSchemaTableName(schemaTableName.getSchemaName(), schemaTableName.getTableName()); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java new file mode 100644 index 0000000000000..f40babdc7f6ab --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.service.ThriftMethod; +import com.facebook.swift.service.ThriftService; +import com.google.common.util.concurrent.ListenableFuture; + +import java.io.Closeable; +import java.util.List; + +/** + * Presto Thrift service definition. + * This thrift service needs to be implemented in order to be used with Thrift Connector. + */ +@ThriftService +public interface PrestoThriftService + extends Closeable +{ + /** + * Returns available schema names. + */ + @ThriftMethod("prestoListSchemaNames") + List listSchemaNames() + throws PrestoThriftServiceException; + + /** + * Returns tables for the given schema name. + * + * @param schemaNameOrNull a structure containing schema name or {@literal null} + * @return a list of table names with corresponding schemas. If schema name is null then returns + * a list of tables for all schemas. Returns an empty list if a schema does not exist + */ + @ThriftMethod("prestoListTables") + List listTables( + @ThriftField(name = "schemaNameOrNull") PrestoThriftNullableSchemaName schemaNameOrNull) + throws PrestoThriftServiceException; + + /** + * Returns metadata for a given table. + * + * @param schemaTableName schema and table name + * @return metadata for a given table, or a {@literal null} value inside if it does not exist + */ + @ThriftMethod("prestoGetTableMetadata") + PrestoThriftNullableTableMetadata getTableMetadata( + @ThriftField(name = "schemaTableName") PrestoThriftSchemaTableName schemaTableName) + throws PrestoThriftServiceException; + + /** + * Returns a batch of splits. + * + * @param schemaTableName schema and table name + * @param desiredColumns a superset of columns to return; empty set means "no columns", {@literal null} set means "all columns" + * @param outputConstraint constraint on the returned data + * @param maxSplitCount maximum number of splits to return + * @param nextToken token from a previous split batch or {@literal null} if it is the first call + * @return a batch of splits + */ + @ThriftMethod("prestoGetSplits") + ListenableFuture getSplits( + @ThriftField(name = "schemaTableName") PrestoThriftSchemaTableName schemaTableName, + @ThriftField(name = "desiredColumns") PrestoThriftNullableColumnSet desiredColumns, + @ThriftField(name = "outputConstraint") PrestoThriftTupleDomain outputConstraint, + @ThriftField(name = "maxSplitCount") int maxSplitCount, + @ThriftField(name = "nextToken") PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException; + + /** + * Returns a batch of rows for the given split. + * + * @param splitId split id as returned in split batch + * @param columns a list of column names to return + * @param maxBytes maximum size of returned data in bytes + * @param nextToken token from a previous batch or {@literal null} if it is the first call + * @return a batch of table data + */ + @ThriftMethod("prestoGetRows") + ListenableFuture getRows( + @ThriftField(name = "splitId") PrestoThriftId splitId, + @ThriftField(name = "columns") List columns, + @ThriftField(name = "maxBytes") long maxBytes, + @ThriftField(name = "nextToken") PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException; + + @Override + void close(); +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftServiceException.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftServiceException.java new file mode 100644 index 0000000000000..fb51007111cf7 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftServiceException.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +@ThriftStruct +public final class PrestoThriftServiceException + extends RuntimeException +{ + private final boolean retryable; + + @ThriftConstructor + public PrestoThriftServiceException(String message, boolean retryable) + { + super(message); + this.retryable = retryable; + } + + @Override + @ThriftField(1) + public String getMessage() + { + return super.getMessage(); + } + + @ThriftField(2) + public boolean isRetryable() + { + return retryable; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplit.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplit.java new file mode 100644 index 0000000000000..a146d9bc727d2 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplit.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftSplit +{ + private final PrestoThriftId splitId; + private final List hosts; + + @ThriftConstructor + public PrestoThriftSplit(PrestoThriftId splitId, List hosts) + { + this.splitId = requireNonNull(splitId, "splitId is null"); + this.hosts = requireNonNull(hosts, "hosts is null"); + } + + @ThriftField(1) + public PrestoThriftId getSplitId() + { + return splitId; + } + + @ThriftField(2) + public List getHosts() + { + return hosts; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftSplit other = (PrestoThriftSplit) obj; + return Objects.equals(this.splitId, other.splitId) && + Objects.equals(this.hosts, other.hosts); + } + + @Override + public int hashCode() + { + return Objects.hash(splitId, hosts); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("splitId", splitId) + .add("hosts", hosts) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplitBatch.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplitBatch.java new file mode 100644 index 0000000000000..95f265207ea99 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftSplitBatch.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftSplitBatch +{ + private final List splits; + private final PrestoThriftId nextToken; + + @ThriftConstructor + public PrestoThriftSplitBatch(List splits, @Nullable PrestoThriftId nextToken) + { + this.splits = requireNonNull(splits, "splits is null"); + this.nextToken = nextToken; + } + + @ThriftField(1) + public List getSplits() + { + return splits; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public PrestoThriftId getNextToken() + { + return nextToken; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftSplitBatch other = (PrestoThriftSplitBatch) obj; + return Objects.equals(this.splits, other.splits) && + Objects.equals(this.nextToken, other.nextToken); + } + + @Override + public int hashCode() + { + return Objects.hash(splits, nextToken); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfSplits", splits.size()) + .add("nextToken", nextToken) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTableMetadata.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTableMetadata.java new file mode 100644 index 0000000000000..0e9fceba52da7 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTableMetadata.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; +import com.google.common.collect.ImmutableMap; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftTableMetadata +{ + private final PrestoThriftSchemaTableName schemaTableName; + private final List columns; + private final String comment; + + @ThriftConstructor + public PrestoThriftTableMetadata( + @ThriftField(name = "schemaTableName") PrestoThriftSchemaTableName schemaTableName, + @ThriftField(name = "columns") List columns, + @ThriftField(name = "comment") @Nullable String comment) + { + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.columns = requireNonNull(columns, "columns is null"); + this.comment = comment; + } + + @ThriftField(1) + public PrestoThriftSchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @ThriftField(2) + public List getColumns() + { + return columns; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public String getComment() + { + return comment; + } + + public ConnectorTableMetadata toConnectorTableMetadata(TypeManager typeManager) + { + return new ConnectorTableMetadata( + schemaTableName.toSchemaTableName(), + columnMetadata(typeManager), + ImmutableMap.of(), + Optional.ofNullable(comment)); + } + + private List columnMetadata(TypeManager typeManager) + { + return columns.stream() + .map(column -> column.toColumnMetadata(typeManager)) + .collect(toImmutableList()); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftTableMetadata other = (PrestoThriftTableMetadata) obj; + return Objects.equals(this.schemaTableName, other.schemaTableName) && + Objects.equals(this.columns, other.columns) && + Objects.equals(this.comment, other.comment); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaTableName, columns, comment); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaTableName", schemaTableName) + .add("numberOfColumns", columns.size()) + .add("comment", comment) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTupleDomain.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTupleDomain.java new file mode 100644 index 0000000000000..a1631d26e278a --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftTupleDomain.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; +import com.google.common.collect.ImmutableSet; + +import javax.annotation.Nullable; + +import java.util.Map; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftTupleDomain +{ + private final Map domains; + + @ThriftConstructor + public PrestoThriftTupleDomain(@Nullable Map domains) + { + if (domains != null) { + for (String name : domains.keySet()) { + checkValidName(name); + } + } + this.domains = domains; + } + + /** + * Return a map of column names to constraints. + */ + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public Map getDomains() + { + return domains; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftTupleDomain other = (PrestoThriftTupleDomain) obj; + return Objects.equals(this.domains, other.domains); + } + + @Override + public int hashCode() + { + return Objects.hashCode(domains); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columnsWithConstraints", domains != null ? domains.keySet() : ImmutableSet.of()) + .toString(); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigint.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigint.java new file mode 100644 index 0000000000000..8ebe4c6c33f54 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigint.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromLongBasedBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code longs} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftBigint + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final long[] longs; + + @ThriftConstructor + public PrestoThriftBigint( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "longs") @Nullable long[] longs) + { + checkArgument(sameSizeIfPresent(nulls, longs), "nulls and values must be of the same size"); + this.nulls = nulls; + this.longs = longs; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public long[] getLongs() + { + return longs; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(BIGINT.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new LongArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + longs == null ? new long[numberOfRecords] : longs); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (longs != null) { + return longs.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBigint other = (PrestoThriftBigint) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.longs, other.longs); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(longs)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromLongBasedBlock(block, BIGINT, (nulls, longs) -> bigintData(new PrestoThriftBigint(nulls, longs))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, long[] longs) + { + return nulls == null || longs == null || nulls.length == longs.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigintArray.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigintArray.java new file mode 100644 index 0000000000000..6232ba23f8f47 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBigintArray.java @@ -0,0 +1,180 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.AbstractArrayBlock; +import com.facebook.presto.spi.block.ArrayBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintArrayData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.calculateOffsets; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.sameSizeIfPresent; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.totalSize; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the number of elements in the corresponding values array. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code values} is a bigint block containing array elements one after another for all rows. + * The total number of elements in bigint block must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftBigintArray + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] sizes; + private final PrestoThriftBigint values; + + @ThriftConstructor + public PrestoThriftBigintArray( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "values") @Nullable PrestoThriftBigint values) + { + checkArgument(sameSizeIfPresent(nulls, sizes), "nulls and values must be of the same size"); + checkArgument(totalSize(nulls, sizes) == numberOfValues(values), "total number of values doesn't match expected size"); + this.nulls = nulls; + this.sizes = sizes; + this.values = values; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sizes; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftBigint getValues() + { + return values; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(desiredType.getTypeParameters().size() == 1 && BIGINT.equals(desiredType.getTypeParameters().get(0)), + "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new ArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + calculateOffsets(sizes, nulls, numberOfRecords), + values != null ? values.toBlock(BIGINT) : new LongArrayBlock(0, new boolean[] {}, new long[] {})); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (sizes != null) { + return sizes.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBigintArray other = (PrestoThriftBigintArray) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.sizes, other.sizes) && + Objects.equals(this.values, other.values); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(sizes), values); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + checkArgument(block instanceof AbstractArrayBlock, "block is not of an array type"); + AbstractArrayBlock arrayBlock = (AbstractArrayBlock) block; + int positions = arrayBlock.getPositionCount(); + if (positions == 0) { + return bigintArrayData(new PrestoThriftBigintArray(null, null, null)); + } + boolean[] nulls = null; + int[] sizes = null; + for (int position = 0; position < positions; position++) { + if (arrayBlock.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (sizes == null) { + sizes = new int[positions]; + } + sizes[position] = arrayBlock.apply((valuesBlock, startPosition, length) -> length, position); + } + } + PrestoThriftBigint values = arrayBlock + .apply((valuesBlock, startPosition, length) -> PrestoThriftBigint.fromBlock(valuesBlock), 0) + .getBigintData(); + checkState(values != null, "values must be present"); + checkState(totalSize(nulls, sizes) == values.numberOfRecords(), "unexpected number of values"); + return bigintArrayData(new PrestoThriftBigintArray(nulls, sizes, values)); + } + + private static int numberOfValues(PrestoThriftBigint values) + { + return values != null ? values.numberOfRecords() : 0; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBoolean.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBoolean.java new file mode 100644 index 0000000000000..1f7067406e76b --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftBoolean.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.ByteArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.booleanData; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code booleans} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftBoolean + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final boolean[] booleans; + + @ThriftConstructor + public PrestoThriftBoolean(@Nullable boolean[] nulls, @Nullable boolean[] booleans) + { + checkArgument(sameSizeIfPresent(nulls, booleans), "nulls and values must be of the same size"); + this.nulls = nulls; + this.booleans = booleans; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public boolean[] getBooleans() + { + return booleans; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(BOOLEAN.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new ByteArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + booleans == null ? new byte[numberOfRecords] : toByteArray(booleans)); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (booleans != null) { + return booleans.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftBoolean other = (PrestoThriftBoolean) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.booleans, other.booleans); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(booleans)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return booleanData(new PrestoThriftBoolean(null, null)); + } + boolean[] nulls = null; + boolean[] booleans = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (booleans == null) { + booleans = new boolean[positions]; + } + booleans[position] = BOOLEAN.getBoolean(block, position); + } + } + return booleanData(new PrestoThriftBoolean(nulls, booleans)); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, boolean[] booleans) + { + return nulls == null || booleans == null || nulls.length == booleans.length; + } + + private static byte[] toByteArray(boolean[] booleans) + { + byte[] bytes = new byte[booleans.length]; + for (int i = 0; i < booleans.length; i++) { + bytes[i] = booleans[i] ? (byte) 1 : (byte) 0; + } + return bytes; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftColumnData.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftColumnData.java new file mode 100644 index 0000000000000..eb687625e002f --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftColumnData.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; + +public interface PrestoThriftColumnData +{ + Block toBlock(Type desiredType); + + int numberOfRecords(); +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDate.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDate.java new file mode 100644 index 0000000000000..2440a4f12d7b4 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDate.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.IntArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.dateData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromIntBasedBlock; +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code dates} array are date values for each row represented as the number of days passed since 1970-01-01. + * If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftDate + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] dates; + + @ThriftConstructor + public PrestoThriftDate( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "dates") @Nullable int[] dates) + { + checkArgument(sameSizeIfPresent(nulls, dates), "nulls and values must be of the same size"); + this.nulls = nulls; + this.dates = dates; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getDates() + { + return dates; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(DATE.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new IntArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + dates == null ? new int[numberOfRecords] : dates); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (dates != null) { + return dates.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftDate other = (PrestoThriftDate) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.dates, other.dates); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(dates)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromIntBasedBlock(block, DATE, (nulls, ints) -> dateData(new PrestoThriftDate(nulls, ints))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, int[] dates) + { + return nulls == null || dates == null || nulls.length == dates.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDouble.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDouble.java new file mode 100644 index 0000000000000..489515d8a4706 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftDouble.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.booleanData; +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.doubleData; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Double.doubleToLongBits; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code doubles} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftDouble + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final double[] doubles; + + @ThriftConstructor + public PrestoThriftDouble(@Nullable boolean[] nulls, @Nullable double[] doubles) + { + checkArgument(sameSizeIfPresent(nulls, doubles), "nulls and values must be of the same size"); + this.nulls = nulls; + this.doubles = doubles; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public double[] getDoubles() + { + return doubles; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(DOUBLE.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + long[] longs = new long[numberOfRecords]; + if (doubles != null) { + for (int i = 0; i < numberOfRecords; i++) { + longs[i] = doubleToLongBits(doubles[i]); + } + } + return new LongArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + longs); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (doubles != null) { + return doubles.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftDouble other = (PrestoThriftDouble) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.doubles, other.doubles); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(doubles)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return booleanData(new PrestoThriftBoolean(null, null)); + } + boolean[] nulls = null; + double[] doubles = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (doubles == null) { + doubles = new double[positions]; + } + doubles[position] = DOUBLE.getDouble(block, position); + } + } + return doubleData(new PrestoThriftDouble(nulls, doubles)); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, double[] doubles) + { + return nulls == null || doubles == null || nulls.length == doubles.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftHyperLogLog.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftHyperLogLog.java new file mode 100644 index 0000000000000..51e8bbe8be4e4 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftHyperLogLog.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.hyperLogLogData; +import static com.facebook.presto.connector.thrift.api.datatypes.SliceData.fromSliceBasedBlock; +import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains encoded byte values for HyperLogLog representation as defined in airlift specification. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftHyperLogLog + implements PrestoThriftColumnData +{ + private final SliceData sliceType; + + @ThriftConstructor + public PrestoThriftHyperLogLog( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "bytes") @Nullable byte[] bytes) + { + this.sliceType = new SliceData(nulls, sizes, bytes); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return sliceType.getNulls(); + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sliceType.getSizes(); + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public byte[] getBytes() + { + return sliceType.getBytes(); + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(HYPER_LOG_LOG.equals(desiredType), "type doesn't match: %s", desiredType); + return sliceType.toBlock(desiredType); + } + + @Override + public int numberOfRecords() + { + return sliceType.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftHyperLogLog other = (PrestoThriftHyperLogLog) obj; + return Objects.equals(this.sliceType, other.sliceType); + } + + @Override + public int hashCode() + { + return sliceType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromSliceBasedBlock(block, HYPER_LOG_LOG, (nulls, sizes, bytes) -> hyperLogLogData(new PrestoThriftHyperLogLog(nulls, sizes, bytes))); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftInteger.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftInteger.java new file mode 100644 index 0000000000000..8be9477330fae --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftInteger.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.IntArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.integerData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromIntBasedBlock; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code ints} array are values for each row. If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftInteger + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] ints; + + @ThriftConstructor + public PrestoThriftInteger( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "ints") @Nullable int[] ints) + { + checkArgument(sameSizeIfPresent(nulls, ints), "nulls and values must be of the same size"); + this.nulls = nulls; + this.ints = ints; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getInts() + { + return ints; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(INTEGER.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new IntArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + ints == null ? new int[numberOfRecords] : ints); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (ints != null) { + return ints.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftInteger other = (PrestoThriftInteger) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.ints, other.ints); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(ints)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromIntBasedBlock(block, INTEGER, (nulls, ints) -> integerData(new PrestoThriftInteger(nulls, ints))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, int[] ints) + { + return nulls == null || ints == null || nulls.length == ints.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftJson.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftJson.java new file mode 100644 index 0000000000000..1e173fc8e18e0 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftJson.java @@ -0,0 +1,119 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.jsonData; +import static com.facebook.presto.connector.thrift.api.datatypes.SliceData.fromSliceBasedBlock; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains uft8 encoded byte values for string representation of json. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftJson + implements PrestoThriftColumnData +{ + private final SliceData sliceType; + + @ThriftConstructor + public PrestoThriftJson( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "bytes") @Nullable byte[] bytes) + { + this.sliceType = new SliceData(nulls, sizes, bytes); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return sliceType.getNulls(); + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sliceType.getSizes(); + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public byte[] getBytes() + { + return sliceType.getBytes(); + } + + @Override + public Block toBlock(Type desiredType) + { + return sliceType.toBlock(desiredType); + } + + @Override + public int numberOfRecords() + { + return sliceType.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftJson other = (PrestoThriftJson) obj; + return Objects.equals(this.sliceType, other.sliceType); + } + + @Override + public int hashCode() + { + return sliceType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block, Type type) + { + return fromSliceBasedBlock(block, type, (nulls, sizes, bytes) -> jsonData(new PrestoThriftJson(nulls, sizes, bytes))); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTimestamp.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTimestamp.java new file mode 100644 index 0000000000000..b08267bcbed41 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTimestamp.java @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.LongArrayBlock; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.timestampData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.fromLongBasedBlock; +import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Elements of {@code timestamps} array are values for each row represented as the number of milliseconds passed since 1970-01-01T00:00:00 UTC. + * If row is null then value is ignored. + */ +@ThriftStruct +public final class PrestoThriftTimestamp + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final long[] timestamps; + + @ThriftConstructor + public PrestoThriftTimestamp( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "timestamps") @Nullable long[] timestamps) + { + checkArgument(sameSizeIfPresent(nulls, timestamps), "nulls and values must be of the same size"); + this.nulls = nulls; + this.timestamps = timestamps; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public long[] getTimestamps() + { + return timestamps; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(TIMESTAMP.equals(desiredType), "type doesn't match: %s", desiredType); + int numberOfRecords = numberOfRecords(); + return new LongArrayBlock( + numberOfRecords, + nulls == null ? new boolean[numberOfRecords] : nulls, + timestamps == null ? new long[numberOfRecords] : timestamps); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (timestamps != null) { + return timestamps.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftTimestamp other = (PrestoThriftTimestamp) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.timestamps, other.timestamps); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(timestamps)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block) + { + return fromLongBasedBlock(block, TIMESTAMP, (nulls, longs) -> timestampData(new PrestoThriftTimestamp(nulls, longs))); + } + + private static boolean sameSizeIfPresent(boolean[] nulls, long[] timestamps) + { + return nulls == null || timestamps == null || nulls.length == timestamps.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTypeUtils.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTypeUtils.java new file mode 100644 index 0000000000000..5825334e0863f --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftTypeUtils.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; + +import java.util.function.BiFunction; + +final class PrestoThriftTypeUtils +{ + private PrestoThriftTypeUtils() + { + } + + public static PrestoThriftBlock fromLongBasedBlock(Block block, Type type, BiFunction result) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return result.apply(null, null); + } + boolean[] nulls = null; + long[] longs = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (longs == null) { + longs = new long[positions]; + } + longs[position] = type.getLong(block, position); + } + } + return result.apply(nulls, longs); + } + + public static PrestoThriftBlock fromIntBasedBlock(Block block, Type type, BiFunction result) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return result.apply(null, null); + } + boolean[] nulls = null; + int[] ints = null; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + if (ints == null) { + ints = new int[positions]; + } + ints[position] = (int) type.getLong(block, position); + } + } + return result.apply(nulls, ints); + } + + public static int totalSize(boolean[] nulls, int[] sizes) + { + int numberOfRecords; + if (nulls != null) { + numberOfRecords = nulls.length; + } + else if (sizes != null) { + numberOfRecords = sizes.length; + } + else { + numberOfRecords = 0; + } + int total = 0; + for (int i = 0; i < numberOfRecords; i++) { + if (nulls == null || !nulls[i]) { + total += sizes[i]; + } + } + return total; + } + + public static int[] calculateOffsets(int[] sizes, boolean[] nulls, int totalRecords) + { + if (sizes == null) { + return new int[totalRecords + 1]; + } + int[] offsets = new int[totalRecords + 1]; + offsets[0] = 0; + for (int i = 0; i < totalRecords; i++) { + int size = nulls != null && nulls[i] ? 0 : sizes[i]; + offsets[i + 1] = offsets[i] + size; + } + return offsets; + } + + public static boolean sameSizeIfPresent(boolean[] nulls, int[] sizes) + { + return nulls == null || sizes == null || nulls.length == sizes.length; + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftVarchar.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftVarchar.java new file mode 100644 index 0000000000000..063d9360bc532 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/PrestoThriftVarchar.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.VarcharType; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.varcharData; +import static com.facebook.presto.connector.thrift.api.datatypes.SliceData.fromSliceBasedBlock; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Elements of {@code nulls} array determine if a value for a corresponding row is null. + * Each elements of {@code sizes} array contains the length in bytes for the corresponding element. + * If row is null then the corresponding element in {@code sizes} is ignored. + * {@code bytes} array contains uft8 encoded byte values. + * Values for all rows are written to {@code bytes} array one after another. + * The total number of bytes must be equal to the sum of all sizes. + */ +@ThriftStruct +public final class PrestoThriftVarchar + implements PrestoThriftColumnData +{ + private final SliceData sliceType; + + @ThriftConstructor + public PrestoThriftVarchar( + @ThriftField(name = "nulls") @Nullable boolean[] nulls, + @ThriftField(name = "sizes") @Nullable int[] sizes, + @ThriftField(name = "bytes") @Nullable byte[] bytes) + { + this.sliceType = new SliceData(nulls, sizes, bytes); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public boolean[] getNulls() + { + return sliceType.getNulls(); + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public int[] getSizes() + { + return sliceType.getSizes(); + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public byte[] getBytes() + { + return sliceType.getBytes(); + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(desiredType.getClass() == VarcharType.class, "type doesn't match: %s", desiredType); + return sliceType.toBlock(desiredType); + } + + @Override + public int numberOfRecords() + { + return sliceType.numberOfRecords(); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftVarchar other = (PrestoThriftVarchar) obj; + return Objects.equals(this.sliceType, other.sliceType); + } + + @Override + public int hashCode() + { + return sliceType.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromBlock(Block block, Type type) + { + return fromSliceBasedBlock(block, type, (nulls, sizes, bytes) -> varcharData(new PrestoThriftVarchar(nulls, sizes, bytes))); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/SliceData.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/SliceData.java new file mode 100644 index 0000000000000..20aedd9f47013 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/datatypes/SliceData.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.VariableWidthBlock; +import com.facebook.presto.spi.type.Type; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.calculateOffsets; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.sameSizeIfPresent; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftTypeUtils.totalSize; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +final class SliceData + implements PrestoThriftColumnData +{ + private final boolean[] nulls; + private final int[] sizes; + private final byte[] bytes; + + public SliceData(@Nullable boolean[] nulls, @Nullable int[] sizes, @Nullable byte[] bytes) + { + checkArgument(sameSizeIfPresent(nulls, sizes), "nulls and values must be of the same size"); + checkArgument(totalSize(nulls, sizes) == (bytes != null ? bytes.length : 0), "total bytes size doesn't match expected size"); + this.nulls = nulls; + this.sizes = sizes; + this.bytes = bytes; + } + + @Nullable + public boolean[] getNulls() + { + return nulls; + } + + @Nullable + public int[] getSizes() + { + return sizes; + } + + @Nullable + public byte[] getBytes() + { + return bytes; + } + + @Override + public Block toBlock(Type desiredType) + { + checkArgument(desiredType.getJavaType() == Slice.class, "type doesn't match: %s", desiredType); + Slice values = bytes == null ? Slices.EMPTY_SLICE : Slices.wrappedBuffer(bytes); + int numberOfRecords = numberOfRecords(); + return new VariableWidthBlock( + numberOfRecords, + values, + calculateOffsets(sizes, nulls, numberOfRecords), + nulls == null ? new boolean[numberOfRecords] : nulls); + } + + @Override + public int numberOfRecords() + { + if (nulls != null) { + return nulls.length; + } + if (sizes != null) { + return sizes.length; + } + return 0; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + SliceData other = (SliceData) obj; + return Arrays.equals(this.nulls, other.nulls) && + Arrays.equals(this.sizes, other.sizes) && + Arrays.equals(this.bytes, other.bytes); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(nulls), Arrays.hashCode(sizes), Arrays.hashCode(bytes)); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRecords", numberOfRecords()) + .toString(); + } + + public static PrestoThriftBlock fromSliceBasedBlock(Block block, Type type, CreateSliceThriftBlockFunction create) + { + int positions = block.getPositionCount(); + if (positions == 0) { + return create.apply(null, null, null); + } + boolean[] nulls = null; + int[] sizes = null; + byte[] bytes = null; + int bytesIndex = 0; + for (int position = 0; position < positions; position++) { + if (block.isNull(position)) { + if (nulls == null) { + nulls = new boolean[positions]; + } + nulls[position] = true; + } + else { + Slice value = type.getSlice(block, position); + if (sizes == null) { + sizes = new int[positions]; + int totalBytes = totalSliceBytes(block); + if (totalBytes > 0) { + bytes = new byte[totalBytes]; + } + } + int length = value.length(); + sizes[position] = length; + if (length > 0) { + checkState(bytes != null); + value.getBytes(0, bytes, bytesIndex, length); + bytesIndex += length; + } + } + } + checkState(bytes == null || bytesIndex == bytes.length); + return create.apply(nulls, sizes, bytes); + } + + private static int totalSliceBytes(Block block) + { + int totalBytes = 0; + int positions = block.getPositionCount(); + for (int position = 0; position < positions; position++) { + totalBytes += block.getSliceLength(position); + } + return totalBytes; + } + + public interface CreateSliceThriftBlockFunction + { + PrestoThriftBlock apply(boolean[] nulls, int[] sizes, byte[] bytes); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftAllOrNoneValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftAllOrNoneValueSet.java new file mode 100644 index 0000000000000..2da881f5be903 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftAllOrNoneValueSet.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.spi.predicate.AllOrNoneValueSet; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import static com.google.common.base.MoreObjects.toStringHelper; + +@ThriftStruct +public final class PrestoThriftAllOrNoneValueSet +{ + private final boolean all; + + @ThriftConstructor + public PrestoThriftAllOrNoneValueSet(boolean all) + { + this.all = all; + } + + @ThriftField(1) + public boolean isAll() + { + return all; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftAllOrNoneValueSet other = (PrestoThriftAllOrNoneValueSet) obj; + return this.all == other.all; + } + + @Override + public int hashCode() + { + return Boolean.hashCode(all); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("all", all) + .toString(); + } + + public static PrestoThriftAllOrNoneValueSet fromAllOrNoneValueSet(AllOrNoneValueSet valueSet) + { + return new PrestoThriftAllOrNoneValueSet(valueSet.isAll()); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftEquatableValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftEquatableValueSet.java new file mode 100644 index 0000000000000..364ba43d032f1 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftEquatableValueSet.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.predicate.EquatableValueSet; +import com.facebook.presto.spi.predicate.EquatableValueSet.ValueEntry; +import com.facebook.presto.spi.type.Type; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftEquatableValueSet +{ + private final boolean whiteList; + private final List values; + + @ThriftConstructor + public PrestoThriftEquatableValueSet(boolean whiteList, List values) + { + this.whiteList = whiteList; + this.values = requireNonNull(values, "values are null"); + } + + @ThriftField(1) + public boolean isWhiteList() + { + return whiteList; + } + + @ThriftField(2) + public List getValues() + { + return values; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftEquatableValueSet other = (PrestoThriftEquatableValueSet) obj; + return this.whiteList == other.whiteList && + Objects.equals(this.values, other.values); + } + + @Override + public int hashCode() + { + return Objects.hash(whiteList, values); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("whiteList", whiteList) + .add("values", values) + .toString(); + } + + public static PrestoThriftEquatableValueSet fromEquatableValueSet(EquatableValueSet valueSet) + { + Type type = valueSet.getType(); + Set values = valueSet.getEntries(); + List thriftValues = new ArrayList<>(values.size()); + for (ValueEntry value : values) { + checkState(type.equals(value.getType()), "ValueEntrySet has elements of different types: %s vs %s", type, value.getType()); + thriftValues.add(fromBlock(value.getBlock(), type)); + } + return new PrestoThriftEquatableValueSet(valueSet.isWhiteList(), thriftValues); + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftRangeValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftRangeValueSet.java new file mode 100644 index 0000000000000..d9391f975be1f --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftRangeValueSet.java @@ -0,0 +1,254 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.predicate.Marker; +import com.facebook.presto.spi.predicate.Marker.Bound; +import com.facebook.presto.spi.predicate.Range; +import com.facebook.presto.spi.predicate.SortedRangeSet; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftEnum; +import com.facebook.swift.codec.ThriftEnumValue; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.fromBound; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftMarker.fromMarker; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public final class PrestoThriftRangeValueSet +{ + private final List ranges; + + @ThriftConstructor + public PrestoThriftRangeValueSet(@ThriftField(name = "ranges") List ranges) + { + this.ranges = requireNonNull(ranges, "ranges is null"); + } + + @ThriftField(1) + public List getRanges() + { + return ranges; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftRangeValueSet other = (PrestoThriftRangeValueSet) obj; + return Objects.equals(this.ranges, other.ranges); + } + + @Override + public int hashCode() + { + return Objects.hashCode(ranges); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("numberOfRanges", ranges.size()) + .toString(); + } + + public static PrestoThriftRangeValueSet fromSortedRangeSet(SortedRangeSet valueSet) + { + List ranges = valueSet.getOrderedRanges().stream() + .map(PrestoThriftRange::fromRange) + .collect(toImmutableList()); + return new PrestoThriftRangeValueSet(ranges); + } + + @ThriftEnum + public enum PrestoThriftBound + { + BELOW(1), // lower than the value, but infinitesimally close to the value + EXACTLY(2), // exactly the value + ABOVE(3); // higher than the value, but infinitesimally close to the value + + private final int value; + + PrestoThriftBound(int value) + { + this.value = value; + } + + @ThriftEnumValue + public int getValue() + { + return value; + } + + public static PrestoThriftBound fromBound(Bound bound) + { + switch (bound) { + case BELOW: + return BELOW; + case EXACTLY: + return EXACTLY; + case ABOVE: + return ABOVE; + default: + throw new IllegalArgumentException("Unknown bound: " + bound); + } + } + } + + /** + * LOWER UNBOUNDED is specified with an empty value and a ABOVE bound + * UPPER UNBOUNDED is specified with an empty value and a BELOW bound + */ + @ThriftStruct + public static final class PrestoThriftMarker + { + private final PrestoThriftBlock value; + private final PrestoThriftBound bound; + + @ThriftConstructor + public PrestoThriftMarker(@Nullable PrestoThriftBlock value, PrestoThriftBound bound) + { + checkArgument(value == null || value.numberOfRecords() == 1, "value must contain exactly one record when present"); + this.value = value; + this.bound = requireNonNull(bound, "bound is null"); + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftBlock getValue() + { + return value; + } + + @ThriftField(2) + public PrestoThriftBound getBound() + { + return bound; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftMarker other = (PrestoThriftMarker) obj; + return Objects.equals(this.value, other.value) && + Objects.equals(this.bound, other.bound); + } + + @Override + public int hashCode() + { + return Objects.hash(value, bound); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("value", value) + .add("bound", bound) + .toString(); + } + + public static PrestoThriftMarker fromMarker(Marker marker) + { + PrestoThriftBlock value = marker.getValueBlock().isPresent() ? fromBlock(marker.getValueBlock().get(), marker.getType()) : null; + return new PrestoThriftMarker(value, fromBound(marker.getBound())); + } + } + + @ThriftStruct + public static final class PrestoThriftRange + { + private final PrestoThriftMarker low; + private final PrestoThriftMarker high; + + @ThriftConstructor + public PrestoThriftRange(PrestoThriftMarker low, PrestoThriftMarker high) + { + this.low = requireNonNull(low, "low is null"); + this.high = requireNonNull(high, "high is null"); + } + + @ThriftField(1) + public PrestoThriftMarker getLow() + { + return low; + } + + @ThriftField(2) + public PrestoThriftMarker getHigh() + { + return high; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftRange other = (PrestoThriftRange) obj; + return Objects.equals(this.low, other.low) && + Objects.equals(this.high, other.high); + } + + @Override + public int hashCode() + { + return Objects.hash(low, high); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("low", low) + .add("high", high) + .toString(); + } + + public static PrestoThriftRange fromRange(Range range) + { + return new PrestoThriftRange(fromMarker(range.getLow()), fromMarker(range.getHigh())); + } + } +} diff --git a/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftValueSet.java b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftValueSet.java new file mode 100644 index 0000000000000..32dce545f6699 --- /dev/null +++ b/presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/valuesets/PrestoThriftValueSet.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.spi.predicate.AllOrNoneValueSet; +import com.facebook.presto.spi.predicate.EquatableValueSet; +import com.facebook.presto.spi.predicate.SortedRangeSet; +import com.facebook.presto.spi.predicate.ValueSet; +import com.facebook.swift.codec.ThriftConstructor; +import com.facebook.swift.codec.ThriftField; +import com.facebook.swift.codec.ThriftStruct; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftAllOrNoneValueSet.fromAllOrNoneValueSet; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftEquatableValueSet.fromEquatableValueSet; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.fromSortedRangeSet; +import static com.facebook.swift.codec.ThriftField.Requiredness.OPTIONAL; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +@ThriftStruct +public final class PrestoThriftValueSet +{ + private final PrestoThriftAllOrNoneValueSet allOrNoneValueSet; + private final PrestoThriftEquatableValueSet equatableValueSet; + private final PrestoThriftRangeValueSet rangeValueSet; + + @ThriftConstructor + public PrestoThriftValueSet( + @Nullable PrestoThriftAllOrNoneValueSet allOrNoneValueSet, + @Nullable PrestoThriftEquatableValueSet equatableValueSet, + @Nullable PrestoThriftRangeValueSet rangeValueSet) + { + checkArgument(isExactlyOneNonNull(allOrNoneValueSet, equatableValueSet, rangeValueSet), "exactly one value set must be present"); + this.allOrNoneValueSet = allOrNoneValueSet; + this.equatableValueSet = equatableValueSet; + this.rangeValueSet = rangeValueSet; + } + + @Nullable + @ThriftField(value = 1, requiredness = OPTIONAL) + public PrestoThriftAllOrNoneValueSet getAllOrNoneValueSet() + { + return allOrNoneValueSet; + } + + @Nullable + @ThriftField(value = 2, requiredness = OPTIONAL) + public PrestoThriftEquatableValueSet getEquatableValueSet() + { + return equatableValueSet; + } + + @Nullable + @ThriftField(value = 3, requiredness = OPTIONAL) + public PrestoThriftRangeValueSet getRangeValueSet() + { + return rangeValueSet; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + PrestoThriftValueSet other = (PrestoThriftValueSet) obj; + return Objects.equals(this.allOrNoneValueSet, other.allOrNoneValueSet) && + Objects.equals(this.equatableValueSet, other.equatableValueSet) && + Objects.equals(this.rangeValueSet, other.rangeValueSet); + } + + @Override + public int hashCode() + { + return Objects.hash(allOrNoneValueSet, equatableValueSet, rangeValueSet); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("valueSet", firstNonNull(allOrNoneValueSet, equatableValueSet, rangeValueSet)) + .toString(); + } + + public static PrestoThriftValueSet fromValueSet(ValueSet valueSet) + { + if (valueSet.getClass() == AllOrNoneValueSet.class) { + return new PrestoThriftValueSet( + fromAllOrNoneValueSet((AllOrNoneValueSet) valueSet), + null, + null); + } + else if (valueSet.getClass() == EquatableValueSet.class) { + return new PrestoThriftValueSet( + null, + fromEquatableValueSet((EquatableValueSet) valueSet), + null); + } + else if (valueSet.getClass() == SortedRangeSet.class) { + return new PrestoThriftValueSet( + null, + null, + fromSortedRangeSet((SortedRangeSet) valueSet)); + } + else { + throw new IllegalArgumentException("Unknown implementation of a value set: " + valueSet.getClass()); + } + } + + private static boolean isExactlyOneNonNull(Object a, Object b, Object c) + { + return a != null && b == null && c == null || + a == null && b != null && c == null || + a == null && b == null && c != null; + } + + private static Object firstNonNull(Object a, Object b, Object c) + { + if (a != null) { + return a; + } + if (b != null) { + return b; + } + if (c != null) { + return c; + } + throw new IllegalArgumentException("All arguments are null"); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestNameValidationUtils.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestNameValidationUtils.java new file mode 100644 index 0000000000000..31b922425cb6d --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestNameValidationUtils.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.NameValidationUtils.checkValidName; +import static org.testng.Assert.assertThrows; + +public class TestNameValidationUtils +{ + @Test + public void testCheckValidColumnName() + throws Exception + { + checkValidName("abc01_def2"); + assertThrows(() -> checkValidName(null)); + assertThrows(() -> checkValidName("")); + assertThrows(() -> checkValidName("Abc")); + assertThrows(() -> checkValidName("0abc")); + assertThrows(() -> checkValidName("_abc")); + assertThrows(() -> checkValidName("aBc")); + assertThrows(() -> checkValidName("ab-c")); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestPrestoThriftId.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestPrestoThriftId.java new file mode 100644 index 0000000000000..695a4e41d25e2 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestPrestoThriftId.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftId.summarize; +import static org.testng.Assert.assertEquals; + +public class TestPrestoThriftId +{ + @Test + public void testSummarize() + throws Exception + { + assertEquals(summarize(bytes()), ""); + assertEquals(summarize(bytes(1)), "01"); + assertEquals(summarize(bytes(255, 254, 253, 252, 251, 250, 249)), "FFFEFDFCFBFAF9"); + assertEquals(summarize(bytes(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 249, 250, 251, 252, 253, 254, 255)), + "00010203040506070809F9FAFBFCFDFEFF"); + assertEquals(summarize(bytes(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 247, 248, 249, 250, 251, 252, 253, 254, 255)), + "0001020304050607..F8F9FAFBFCFDFEFF"); + } + + private static byte[] bytes(int... values) + { + int length = values.length; + byte[] result = new byte[length]; + for (int i = 0; i < length; i++) { + result[i] = (byte) values[i]; + } + return result; + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestReadWrite.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestReadWrite.java new file mode 100644 index 0000000000000..b1153a56c394e --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/TestReadWrite.java @@ -0,0 +1,469 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api; + +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.VarcharType; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.stats.cardinality.HyperLogLog; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DateType.DATE; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.spi.type.VarcharType.createVarcharType; +import static com.facebook.presto.type.JsonType.JSON; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestReadWrite +{ + private static final double NULL_FRACTION = 0.1; + private static final int MAX_VARCHAR_GENERATED_LENGTH = 64; + private static final char[] SYMBOLS; + private static final long MIN_GENERATED_TIMESTAMP; + private static final long MAX_GENERATED_TIMESTAMP; + private static final int MIN_GENERATED_DATE; + private static final int MAX_GENERATED_DATE; + private static final int MAX_GENERATED_JSON_KEY_LENGTH = 8; + private static final int HYPER_LOG_LOG_BUCKETS = 128; + private static final int MAX_HYPER_LOG_LOG_ELEMENTS = 32; + private static final int MAX_ARRAY_GENERATED_LENGTH = 64; + private final AtomicLong singleValueSeedGenerator = new AtomicLong(762103512L); + private final AtomicLong columnDataSeedGenerator = new AtomicLong(762103512L); + private final List columns = ImmutableList.of( + new IntegerColumn(), + new BigintColumn(), + new DoubleColumn(), + new VarcharColumn(createUnboundedVarcharType()), + new VarcharColumn(createVarcharType(MAX_VARCHAR_GENERATED_LENGTH / 2)), + new BooleanColumn(), + new DateColumn(), + new TimestampColumn(), + new JsonColumn(), + new HyperLogLogColumn(), + new BigintArrayColumn()); + + static { + char[] symbols = new char[2 * 26 + 10]; + int next = 0; + for (char ch = 'A'; ch <= 'Z'; ch++) { + symbols[next++] = ch; + } + for (char ch = 'a'; ch <= 'z'; ch++) { + symbols[next++] = ch; + } + for (char ch = '0'; ch <= '9'; ch++) { + symbols[next++] = ch; + } + SYMBOLS = symbols; + + Calendar calendar = Calendar.getInstance(); + + calendar.set(2000, Calendar.JANUARY, 1); + MIN_GENERATED_TIMESTAMP = calendar.getTimeInMillis(); + MIN_GENERATED_DATE = toIntExact(MILLISECONDS.toDays(MIN_GENERATED_TIMESTAMP)); + + calendar.set(2020, Calendar.DECEMBER, 31); + MAX_GENERATED_TIMESTAMP = calendar.getTimeInMillis(); + MAX_GENERATED_DATE = toIntExact(MILLISECONDS.toDays(MAX_GENERATED_TIMESTAMP)); + } + + @Test(invocationCount = 20) + public void testReadWriteSingleValue() + throws Exception + { + testReadWrite(new Random(singleValueSeedGenerator.incrementAndGet()), 1); + } + + @Test(invocationCount = 20) + public void testReadWriteColumnData() + throws Exception + { + Random random = new Random(columnDataSeedGenerator.incrementAndGet()); + int records = random.nextInt(10000) + 10000; + testReadWrite(random, records); + } + + private void testReadWrite(Random random, int records) + throws Exception + { + // generate columns data + List inputBlocks = new ArrayList<>(columns.size()); + for (ColumnDefinition column : columns) { + inputBlocks.add(generateColumn(column, random, records)); + } + + // convert column data to thrift ("write step") + List columnBlocks = new ArrayList<>(columns.size()); + for (int i = 0; i < columns.size(); i++) { + columnBlocks.add(fromBlock(inputBlocks.get(i), columns.get(i).getType())); + } + PrestoThriftPageResult batch = new PrestoThriftPageResult(columnBlocks, records, null); + + // convert thrift data to page/blocks ("read step") + Page page = batch.toPage(columns.stream().map(ColumnDefinition::getType).collect(toImmutableList())); + + // compare the result with original input + assertNotNull(page); + assertEquals(page.getChannelCount(), columns.size()); + for (int i = 0; i < columns.size(); i++) { + Block actual = page.getBlock(i); + Block expected = inputBlocks.get(i); + assertBlock(actual, expected, columns.get(i)); + } + } + + private static Block generateColumn(ColumnDefinition column, Random random, int records) + { + BlockBuilder builder = column.getType().createBlockBuilder(new BlockBuilderStatus(), records); + for (int i = 0; i < records; i++) { + if (random.nextDouble() < NULL_FRACTION) { + builder.appendNull(); + } + else { + column.writeNextRandomValue(random, builder); + } + } + return builder.build(); + } + + private static void assertBlock(Block actual, Block expected, ColumnDefinition columnDefinition) + { + assertEquals(actual.getPositionCount(), expected.getPositionCount()); + int positions = actual.getPositionCount(); + for (int i = 0; i < positions; i++) { + Object actualValue = columnDefinition.extractValue(actual, i); + Object expectedValue = columnDefinition.extractValue(expected, i); + assertEquals(actualValue, expectedValue); + } + } + + private static String nextString(Random random) + { + return nextString(random, MAX_VARCHAR_GENERATED_LENGTH); + } + + private static String nextString(Random random, int maxLength) + { + int size = random.nextInt(maxLength); + char[] result = new char[size]; + for (int i = 0; i < size; i++) { + result[i] = SYMBOLS[random.nextInt(SYMBOLS.length)]; + } + return new String(result); + } + + private static long nextTimestamp(Random random) + { + return MIN_GENERATED_TIMESTAMP + (long) (random.nextDouble() * (MAX_GENERATED_TIMESTAMP - MIN_GENERATED_TIMESTAMP)); + } + + private static int nextDate(Random random) + { + return MIN_GENERATED_DATE + random.nextInt(MAX_GENERATED_DATE - MIN_GENERATED_DATE); + } + + private static Slice nextHyperLogLog(Random random) + { + HyperLogLog hll = HyperLogLog.newInstance(HYPER_LOG_LOG_BUCKETS); + int size = random.nextInt(MAX_HYPER_LOG_LOG_ELEMENTS); + for (int i = 0; i < size; i++) { + hll.add(random.nextLong()); + } + return hll.serialize(); + } + + private static void generateBigintArray(Random random, BlockBuilder parentBuilder) + { + int numberOfElements = random.nextInt(MAX_ARRAY_GENERATED_LENGTH); + BlockBuilder builder = parentBuilder.beginBlockEntry(); + for (int i = 0; i < numberOfElements; i++) { + if (random.nextDouble() < NULL_FRACTION) { + builder.appendNull(); + } + else { + builder.writeLong(random.nextLong()); + } + } + parentBuilder.closeEntry(); + } + + private abstract static class ColumnDefinition + { + private final Type type; + + public ColumnDefinition(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + public Type getType() + { + return type; + } + + abstract Object extractValue(Block block, int position); + + abstract void writeNextRandomValue(Random random, BlockBuilder builder); + } + + private static final class IntegerColumn + extends ColumnDefinition + { + public IntegerColumn() + { + super(INTEGER); + } + + @Override + Object extractValue(Block block, int position) + { + return INTEGER.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + INTEGER.writeLong(builder, random.nextInt()); + } + } + + private static final class BigintColumn + extends ColumnDefinition + { + public BigintColumn() + { + super(BIGINT); + } + + @Override + Object extractValue(Block block, int position) + { + return BIGINT.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + BIGINT.writeLong(builder, random.nextLong()); + } + } + + private static final class DoubleColumn + extends ColumnDefinition + { + public DoubleColumn() + { + super(DOUBLE); + } + + @Override + Object extractValue(Block block, int position) + { + return DOUBLE.getDouble(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + DOUBLE.writeDouble(builder, random.nextDouble()); + } + } + + private static final class VarcharColumn + extends ColumnDefinition + { + private final VarcharType varcharType; + + public VarcharColumn(VarcharType varcharType) + { + super(varcharType); + this.varcharType = requireNonNull(varcharType, "varcharType is null"); + } + + @Override + Object extractValue(Block block, int position) + { + return varcharType.getSlice(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + varcharType.writeString(builder, nextString(random)); + } + } + + private static final class BooleanColumn + extends ColumnDefinition + { + public BooleanColumn() + { + super(BOOLEAN); + } + + @Override + Object extractValue(Block block, int position) + { + return BOOLEAN.getBoolean(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + BOOLEAN.writeBoolean(builder, random.nextBoolean()); + } + } + + private static final class DateColumn + extends ColumnDefinition + { + public DateColumn() + { + super(DATE); + } + + @Override + Object extractValue(Block block, int position) + { + return DATE.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + DATE.writeLong(builder, nextDate(random)); + } + } + + private static final class TimestampColumn + extends ColumnDefinition + { + public TimestampColumn() + { + super(TIMESTAMP); + } + + @Override + Object extractValue(Block block, int position) + { + return TIMESTAMP.getLong(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + TIMESTAMP.writeLong(builder, nextTimestamp(random)); + } + } + + private static final class JsonColumn + extends ColumnDefinition + { + public JsonColumn() + { + super(JSON); + } + + @Override + Object extractValue(Block block, int position) + { + return JSON.getSlice(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + String json = String.format("{\"%s\": %d, \"%s\": \"%s\"}", + nextString(random, MAX_GENERATED_JSON_KEY_LENGTH), + random.nextInt(), + nextString(random, MAX_GENERATED_JSON_KEY_LENGTH), + random.nextInt()); + JSON.writeString(builder, json); + } + } + + private static final class HyperLogLogColumn + extends ColumnDefinition + { + public HyperLogLogColumn() + { + super(HYPER_LOG_LOG); + } + + @Override + Object extractValue(Block block, int position) + { + return HYPER_LOG_LOG.getSlice(block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + HYPER_LOG_LOG.writeSlice(builder, nextHyperLogLog(random)); + } + } + + private static final class BigintArrayColumn + extends ColumnDefinition + { + private final ArrayType arrayType; + + public BigintArrayColumn() + { + this(new ArrayType(BIGINT)); + } + + private BigintArrayColumn(ArrayType arrayType) + { + super(arrayType); + this.arrayType = requireNonNull(arrayType, "arrayType is null"); + } + + @Override + Object extractValue(Block block, int position) + { + return arrayType.getObjectValue(null, block, position); + } + + @Override + void writeNextRandomValue(Random random, BlockBuilder builder) + { + generateBigintArray(random, builder); + } + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/datatypes/TestPrestoThriftBigint.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/datatypes/TestPrestoThriftBigint.java new file mode 100644 index 0000000000000..625e1a04a6135 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/datatypes/TestPrestoThriftBigint.java @@ -0,0 +1,203 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.datatypes; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.block.BlockBuilderStatus; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintData; +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.integerData; +import static com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigint.fromBlock; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static java.util.Collections.unmodifiableList; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +public class TestPrestoThriftBigint +{ + @Test + public void testReadBlock() + throws Exception + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {false, true, false, false, false, false, true}, + new long[] {2, 0, 1, 3, 8, 4, 0} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(2L, null, 1L, 3L, 8L, 4L, null)); + } + + @Test + public void testReadBlockAllNullsOption1() + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {true, true, true, true, true, true, true}, + null + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(null, null, null, null, null, null, null)); + } + + @Test + public void testReadBlockAllNullsOption2() + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {true, true, true, true, true, true, true}, + new long[] {0, 0, 0, 0, 0, 0, 0} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(null, null, null, null, null, null, null)); + } + + @Test + public void testReadBlockAllNonNullOption1() + throws Exception + { + PrestoThriftBlock columnsData = longColumn( + null, + new long[] {2, 7, 1, 3, 8, 4, 5} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(2L, 7L, 1L, 3L, 8L, 4L, 5L)); + } + + @Test + public void testReadBlockAllNonNullOption2() + throws Exception + { + PrestoThriftBlock columnsData = longColumn( + new boolean[] {false, false, false, false, false, false, false}, + new long[] {2, 7, 1, 3, 8, 4, 5} + ); + Block actual = columnsData.toBlock(BIGINT); + assertBlockEquals(actual, list(2L, 7L, 1L, 3L, 8L, 4L, 5L)); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testReadBlockWrongActualType() + throws Exception + { + PrestoThriftBlock columnsData = integerData(new PrestoThriftInteger(null, null)); + columnsData.toBlock(BIGINT); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testReadBlockWrongDesiredType() + throws Exception + { + PrestoThriftBlock columnsData = longColumn(null, null); + columnsData.toBlock(INTEGER); + } + + @Test + public void testWriteBlockAlternating() + throws Exception + { + Block source = longBlock(1, null, 2, null, 3, null, 4, null, 5, null, 6, null, 7, null); + PrestoThriftBlock column = fromBlock(source); + assertNotNull(column.getBigintData()); + assertEquals(column.getBigintData().getNulls(), + new boolean[] {false, true, false, true, false, true, false, true, false, true, false, true, false, true}); + assertEquals(column.getBigintData().getLongs(), + new long[] {1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0}); + } + + @Test + public void testWriteBlockAllNulls() + throws Exception + { + Block source = longBlock(null, null, null, null, null); + PrestoThriftBlock column = fromBlock(source); + assertNotNull(column.getBigintData()); + assertEquals(column.getBigintData().getNulls(), new boolean[] {true, true, true, true, true}); + assertNull(column.getBigintData().getLongs()); + } + + @Test + public void testWriteBlockAllNonNull() + throws Exception + { + Block source = longBlock(1, 2, 3, 4, 5); + PrestoThriftBlock column = fromBlock(source); + assertNotNull(column.getBigintData()); + assertNull(column.getBigintData().getNulls()); + assertEquals(column.getBigintData().getLongs(), new long[] {1, 2, 3, 4, 5}); + } + + @Test + public void testWriteBlockEmpty() + throws Exception + { + PrestoThriftBlock column = fromBlock(longBlock()); + assertNotNull(column.getBigintData()); + assertNull(column.getBigintData().getNulls()); + assertNull(column.getBigintData().getLongs()); + } + + @Test + public void testWriteBlockSingleValue() + throws Exception + { + PrestoThriftBlock column = fromBlock(longBlock(1)); + assertNotNull(column.getBigintData()); + assertNull(column.getBigintData().getNulls()); + assertEquals(column.getBigintData().getLongs(), new long[] {1}); + } + + private void assertBlockEquals(Block block, List expected) + { + assertEquals(block.getPositionCount(), expected.size()); + for (int i = 0; i < expected.size(); i++) { + if (expected.get(i) == null) { + assertTrue(block.isNull(i)); + } + else { + assertEquals(block.getLong(i, 0), expected.get(i).longValue()); + } + } + } + + private static Block longBlock(Integer... values) + { + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), values.length); + for (Integer value : values) { + if (value == null) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(value).closeEntry(); + } + } + return blockBuilder.build(); + } + + private static PrestoThriftBlock longColumn(boolean[] nulls, long[] longs) + { + return bigintData(new PrestoThriftBigint(nulls, longs)); + } + + private static List list(Long... values) + { + return unmodifiableList(Arrays.asList(values)); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftAllOrNoneValueSet.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftAllOrNoneValueSet.java new file mode 100644 index 0000000000000..d1ee43697ad7d --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftAllOrNoneValueSet.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.spi.predicate.ValueSet; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestPrestoThriftAllOrNoneValueSet +{ + @Test + public void testFromValueSetAll() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.all(HYPER_LOG_LOG)); + assertNotNull(thriftValueSet.getAllOrNoneValueSet()); + assertTrue(thriftValueSet.getAllOrNoneValueSet().isAll()); + } + + @Test + public void testFromValueSetNone() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.none(HYPER_LOG_LOG)); + assertNotNull(thriftValueSet.getAllOrNoneValueSet()); + assertFalse(thriftValueSet.getAllOrNoneValueSet().isAll()); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftEquatableValueSet.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftEquatableValueSet.java new file mode 100644 index 0000000000000..068efe3ced812 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftEquatableValueSet.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftJson; +import com.facebook.presto.spi.predicate.ValueSet; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.jsonData; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.facebook.presto.type.JsonType.JSON; +import static io.airlift.slice.Slices.utf8Slice; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestPrestoThriftEquatableValueSet +{ + private static final String JSON1 = "\"key1\":\"value1\""; + private static final String JSON2 = "\"key2\":\"value2\""; + + @Test + public void testFromValueSetAll() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.all(JSON)); + assertNotNull(thriftValueSet.getEquatableValueSet()); + assertFalse(thriftValueSet.getEquatableValueSet().isWhiteList()); + assertTrue(thriftValueSet.getEquatableValueSet().getValues().isEmpty()); + } + + @Test + public void testFromValueSetNone() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.none(JSON)); + assertNotNull(thriftValueSet.getEquatableValueSet()); + assertTrue(thriftValueSet.getEquatableValueSet().isWhiteList()); + assertTrue(thriftValueSet.getEquatableValueSet().getValues().isEmpty()); + } + + @Test + public void testFromValueSetOf() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.of(JSON, utf8Slice(JSON1), utf8Slice(JSON2))); + assertNotNull(thriftValueSet.getEquatableValueSet()); + assertTrue(thriftValueSet.getEquatableValueSet().isWhiteList()); + assertEquals(thriftValueSet.getEquatableValueSet().getValues(), ImmutableList.of( + jsonData(new PrestoThriftJson(null, new int[] {JSON1.length()}, JSON1.getBytes(UTF_8))), + jsonData(new PrestoThriftJson(null, new int[] {JSON2.length()}, JSON2.getBytes(UTF_8))))); + } +} diff --git a/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftRangeValueSet.java b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftRangeValueSet.java new file mode 100644 index 0000000000000..16c82af88e929 --- /dev/null +++ b/presto-thrift-connector-api/src/test/java/com/facebook/presto/connector/thrift/api/valuesets/TestPrestoThriftRangeValueSet.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.api.valuesets; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.connector.thrift.api.datatypes.PrestoThriftBigint; +import com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftMarker; +import com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftRange; +import com.facebook.presto.spi.predicate.Range; +import com.facebook.presto.spi.predicate.ValueSet; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.bigintData; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.ABOVE; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.BELOW; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftRangeValueSet.PrestoThriftBound.EXACTLY; +import static com.facebook.presto.connector.thrift.api.valuesets.PrestoThriftValueSet.fromValueSet; +import static com.facebook.presto.spi.predicate.Range.range; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestPrestoThriftRangeValueSet +{ + @Test + public void testFromValueSetAll() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.all(BIGINT)); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(null, ABOVE), new PrestoThriftMarker(null, BELOW)))); + } + + @Test + public void testFromValueSetNone() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.none(BIGINT)); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of()); + } + + @Test + public void testFromValueSetOf() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.of(BIGINT, 1L, 2L, 3L)); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(longValue(1), EXACTLY), new PrestoThriftMarker(longValue(1), EXACTLY)), + new PrestoThriftRange(new PrestoThriftMarker(longValue(2), EXACTLY), new PrestoThriftMarker(longValue(2), EXACTLY)), + new PrestoThriftRange(new PrestoThriftMarker(longValue(3), EXACTLY), new PrestoThriftMarker(longValue(3), EXACTLY)))); + } + + @Test + public void testFromValueSetOfRangesUnbounded() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.ofRanges(Range.greaterThanOrEqual(BIGINT, 0L))); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(longValue(0), EXACTLY), new PrestoThriftMarker(null, BELOW)))); + } + + @Test + public void testFromValueSetOfRangesBounded() + throws Exception + { + PrestoThriftValueSet thriftValueSet = fromValueSet(ValueSet.ofRanges( + range(BIGINT, -10L, true, -1L, false), + range(BIGINT, -1L, false, 100L, true))); + assertNotNull(thriftValueSet.getRangeValueSet()); + assertEquals(thriftValueSet.getRangeValueSet().getRanges(), ImmutableList.of( + new PrestoThriftRange(new PrestoThriftMarker(longValue(-10), EXACTLY), new PrestoThriftMarker(longValue(-1), BELOW)), + new PrestoThriftRange(new PrestoThriftMarker(longValue(-1), ABOVE), new PrestoThriftMarker(longValue(100), EXACTLY)))); + } + + private static PrestoThriftBlock longValue(long value) + { + return bigintData(new PrestoThriftBigint(null, new long[] {value})); + } +} diff --git a/presto-thrift-connector/README.md b/presto-thrift-connector/README.md new file mode 100644 index 0000000000000..6a564e2b32a42 --- /dev/null +++ b/presto-thrift-connector/README.md @@ -0,0 +1,23 @@ +Thrift Connector +================ + +Thrift Connector makes it possible to integrate with external storage systems without a custom Presto connector implementation. + +In order to use Thrift Connector with external system you need to implement `PrestoThriftService` interface defined in `presto-thrift-connector-api` project. +Next, you configure Thrift Connector to point to a set of machines, called thrift servers, implementing it. +As part of the interface implementation thrift servers will provide metadata, splits and data. +Thrift server instances are assumed to be stateless and independent from each other. + +Using Thrift Connector over a custom Presto connector can be especially useful in the following cases. + +* Java client for a storage system is not available. +By using Thrift as transport and service definition Thrift Connector can integrate with systems written in non-Java languages. + +* Storage system's model doesn't easily map to metadata/table/row concept or there are multiple ways to do it. +For example, there are multiple ways how to map data from a key/value storage to relational representation. +Instead of supporting all of the variations in the connector this task can be moved to the external system itself. + +* You cannot or don't want to modify Presto code to add a custom connector to support your storage system. + +You can find thrift service interface that needs to be implemented together with related thrift structures in `presto-thrift-connector-api` project. +Documentation of [`PrestoThriftService`](../presto-thrift-connector-api/src/main/java/com/facebook/presto/connector/thrift/api/PrestoThriftService.java) is a good starting point. diff --git a/presto-thrift-connector/pom.xml b/presto-thrift-connector/pom.xml new file mode 100644 index 0000000000000..0d55fff7e4d1c --- /dev/null +++ b/presto-thrift-connector/pom.xml @@ -0,0 +1,156 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.180-SNAPSHOT + + + presto-thrift-connector + Presto - Thrift Connector + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-thrift-connector-api + + + + com.google.guava + guava + + + + com.google.code.findbugs + annotations + + + + com.facebook.swift + swift-codec + + + + com.facebook.swift + swift-service + + + + io.airlift + bootstrap + + + + io.airlift + json + + + + io.airlift + log + + + + org.weakref + jmxutils + + + + com.google.inject + guice + + + + javax.inject + javax.inject + + + + io.airlift + configuration + + + + com.facebook.nifty + nifty-client + + + + javax.validation + validation-api + + + + io.airlift + concurrent + + + + com.facebook.presto + presto-spi + provided + + + + io.airlift + slice + provided + + + + io.airlift + units + provided + + + + org.openjdk.jol + jol-core + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + + org.testng + testng + test + + + + com.facebook.presto + presto-thrift-testing-server + test + + + + io.airlift + testing + test + + + + com.facebook.presto + presto-tests + test + + + + com.facebook.presto + presto-main + test + + + diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftColumnHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftColumnHandle.java new file mode 100644 index 0000000000000..21b3acc6b43d4 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftColumnHandle.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ThriftColumnHandle + implements ColumnHandle +{ + private final String columnName; + private final Type columnType; + private final String comment; + private final boolean hidden; + + @JsonCreator + public ThriftColumnHandle( + @JsonProperty("columnName") String columnName, + @JsonProperty("columnType") Type columnType, + @JsonProperty("comment") @Nullable String comment, + @JsonProperty("hidden") boolean hidden) + { + this.columnName = requireNonNull(columnName, "columnName is null"); + this.columnType = requireNonNull(columnType, "columnType is null"); + this.comment = comment; + this.hidden = hidden; + } + + public ThriftColumnHandle(ColumnMetadata columnMetadata) + { + this(columnMetadata.getName(), columnMetadata.getType(), columnMetadata.getComment(), columnMetadata.isHidden()); + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @JsonProperty + public Type getColumnType() + { + return columnType; + } + + @Nullable + @JsonProperty + public String getComment() + { + return comment; + } + + @JsonProperty + public boolean isHidden() + { + return hidden; + } + + public ColumnMetadata toColumnMetadata() + { + return new ColumnMetadata(columnName, columnType, comment, hidden); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ThriftColumnHandle other = (ThriftColumnHandle) obj; + return Objects.equals(this.columnName, other.columnName) && + Objects.equals(this.columnType, other.columnType) && + Objects.equals(this.comment, other.comment) && + this.hidden == other.hidden; + } + + @Override + public int hashCode() + { + return Objects.hash(columnName, columnType, comment, hidden); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("columnName", columnName) + .add("columnType", columnType) + .add("comment", comment) + .add("hidden", hidden) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnector.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnector.java new file mode 100644 index 0000000000000..01d96f40026a9 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnector.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.session.PropertyMetadata; +import com.facebook.presto.spi.transaction.IsolationLevel; +import io.airlift.bootstrap.LifeCycleManager; +import io.airlift.log.Logger; + +import javax.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ThriftConnector + implements Connector +{ + private static final Logger log = Logger.get(ThriftConnector.class); + + private final LifeCycleManager lifeCycleManager; + private final ThriftMetadata metadata; + private final ThriftSplitManager splitManager; + private final ThriftPageSourceProvider pageSourceProvider; + private final ThriftSessionProperties sessionProperties; + + @Inject + public ThriftConnector( + LifeCycleManager lifeCycleManager, + ThriftMetadata metadata, + ThriftSplitManager splitManager, + ThriftPageSourceProvider pageSourceProvider, + ThriftSessionProperties sessionProperties) + { + this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); + this.sessionProperties = requireNonNull(sessionProperties, "sessionProperties is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return ThriftTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return splitManager; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return pageSourceProvider; + } + + @Override + public List> getSessionProperties() + { + return sessionProperties.getSessionProperties(); + } + + @Override + public final void shutdown() + { + try { + lifeCycleManager.stop(); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + log.error(ie, "Interrupted while shutting down connector"); + } + catch (Exception e) { + log.error(e, "Error shutting down connector"); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorConfig.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorConfig.java new file mode 100644 index 0000000000000..4eb43e17e557a --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorConfig.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import io.airlift.configuration.Config; +import io.airlift.units.DataSize; +import io.airlift.units.MaxDataSize; +import io.airlift.units.MinDataSize; + +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; + +public class ThriftConnectorConfig +{ + private DataSize maxResponseSize = new DataSize(16, MEGABYTE); + private int metadataRefreshThreads = 1; + + @NotNull + @MinDataSize("1MB") + @MaxDataSize("32MB") + public DataSize getMaxResponseSize() + { + return maxResponseSize; + } + + @Config("presto-thrift.max-response-size") + public ThriftConnectorConfig setMaxResponseSize(DataSize maxResponseSize) + { + this.maxResponseSize = maxResponseSize; + return this; + } + + @Min(1) + public int getMetadataRefreshThreads() + { + return metadataRefreshThreads; + } + + @Config("presto-thrift.metadata-refresh-threads") + public ThriftConnectorConfig setMetadataRefreshThreads(int metadataRefreshThreads) + { + this.metadataRefreshThreads = metadataRefreshThreads; + return this; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorFactory.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorFactory.java new file mode 100644 index 0000000000000..d488d8fe31367 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorFactory.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.util.RebindSafeMBeanServer; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.swift.codec.guice.ThriftCodecModule; +import com.facebook.swift.service.guice.ThriftClientModule; +import com.facebook.swift.service.guice.ThriftClientStatsModule; +import com.google.inject.Injector; +import com.google.inject.Module; +import io.airlift.bootstrap.Bootstrap; +import io.airlift.json.JsonModule; +import org.weakref.jmx.guice.MBeanModule; + +import javax.management.MBeanServer; + +import java.util.Map; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.lang.management.ManagementFactory.getPlatformMBeanServer; +import static java.util.Objects.requireNonNull; + +public class ThriftConnectorFactory + implements ConnectorFactory +{ + private final String name; + private final Module locationModule; + + public ThriftConnectorFactory(String name, Module locationModule) + { + this.name = requireNonNull(name, "name is null"); + this.locationModule = requireNonNull(locationModule, "locationModule is null"); + } + + @Override + public String getName() + { + return name; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new ThriftHandleResolver(); + } + + @Override + public Connector create(String connectorId, Map config, ConnectorContext context) + { + try { + Bootstrap app = new Bootstrap( + new JsonModule(), + new MBeanModule(), + new ThriftCodecModule(), + new ThriftClientModule(), + new ThriftClientStatsModule(), + binder -> { + binder.bind(MBeanServer.class).toInstance(new RebindSafeMBeanServer(getPlatformMBeanServer())); + binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + }, + locationModule, + new ThriftModule()); + + Injector injector = app + .strictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(ThriftConnector.class); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while creating connector", ie); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorSplit.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorSplit.java new file mode 100644 index 0000000000000..953c140b49667 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftConnectorSplit.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class ThriftConnectorSplit + implements ConnectorSplit +{ + private final PrestoThriftId splitId; + private final List addresses; + + @JsonCreator + public ThriftConnectorSplit( + @JsonProperty("splitId") PrestoThriftId splitId, + @JsonProperty("addresses") List addresses) + { + this.splitId = requireNonNull(splitId, "splitId is null"); + this.addresses = ImmutableList.copyOf(requireNonNull(addresses, "addresses is null")); + } + + @JsonProperty + public PrestoThriftId getSplitId() + { + return splitId; + } + + @Override + @JsonProperty + public List getAddresses() + { + return addresses; + } + + @Override + public Object getInfo() + { + return this; + } + + @Override + public boolean isRemotelyAccessible() + { + return true; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ThriftConnectorSplit other = (ThriftConnectorSplit) obj; + return Objects.equals(this.splitId, other.splitId) && + Objects.equals(this.addresses, other.addresses); + } + + @Override + public int hashCode() + { + return Objects.hash(splitId, addresses); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("splitId", splitId) + .add("addresses", addresses) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftErrorCode.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftErrorCode.java new file mode 100644 index 0000000000000..e4c4149c1e5f4 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftErrorCode.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ErrorCode; +import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.ErrorType; + +import static com.facebook.presto.spi.ErrorType.EXTERNAL; + +public enum ThriftErrorCode + implements ErrorCodeSupplier +{ + THRIFT_SERVICE_CONNECTION_ERROR(1, EXTERNAL), + THRIFT_SERVICE_INVALID_RESPONSE(2, EXTERNAL); + + private final ErrorCode errorCode; + + ThriftErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0105, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftHandleResolver.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftHandleResolver.java new file mode 100644 index 0000000000000..413b42020118c --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftHandleResolver.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public class ThriftHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return ThriftTableHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return ThriftTableLayoutHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return ThriftColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return ThriftConnectorSplit.class; + } + + @Override + public Class getTransactionHandleClass() + { + return ThriftTransactionHandle.class; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftMetadata.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftMetadata.java new file mode 100644 index 0000000000000..57ba6ca398c11 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftMetadata.java @@ -0,0 +1,194 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.annotations.ForMetadataRefresh; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableSchemaName; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.TableNotFoundException; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.type.TypeManager; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableList; +import io.airlift.units.Duration; + +import javax.annotation.Nonnull; +import javax.inject.Inject; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Executor; + +import static com.facebook.presto.connector.thrift.ThriftErrorCode.THRIFT_SERVICE_INVALID_RESPONSE; +import static com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName.fromSchemaTableName; +import static com.google.common.cache.CacheLoader.asyncReloading; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.function.Function.identity; + +public class ThriftMetadata + implements ConnectorMetadata +{ + private static final Duration EXPIRE_AFTER_WRITE = new Duration(10, MINUTES); + private static final Duration REFRESH_AFTER_WRITE = new Duration(2, MINUTES); + + private final PrestoThriftServiceProvider clientProvider; + private final TypeManager typeManager; + private final LoadingCache> tableCache; + + @Inject + public ThriftMetadata( + PrestoThriftServiceProvider clientProvider, + TypeManager typeManager, + @ForMetadataRefresh Executor metadataRefreshExecutor) + { + this.clientProvider = requireNonNull(clientProvider, "clientProvider is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.tableCache = CacheBuilder.newBuilder() + .expireAfterWrite(EXPIRE_AFTER_WRITE.toMillis(), MILLISECONDS) + .refreshAfterWrite(REFRESH_AFTER_WRITE.toMillis(), MILLISECONDS) + .build(asyncReloading(new CacheLoader>() + { + @Override + public Optional load(@Nonnull SchemaTableName schemaTableName) + throws Exception + { + return getTableMetadataInternal(schemaTableName); + } + }, metadataRefreshExecutor)); + } + + @Override + public List listSchemaNames(ConnectorSession session) + { + return clientProvider.runOnAnyHost(PrestoThriftService::listSchemaNames); + } + + @Override + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) + { + return tableCache.getUnchecked(tableName) + .map(ConnectorTableMetadata::getTable) + .map(ThriftTableHandle::new) + .orElse(null); + } + + @Override + public List getTableLayouts( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) + { + ThriftTableHandle tableHandle = (ThriftTableHandle) table; + ThriftTableLayoutHandle layoutHandle = new ThriftTableLayoutHandle( + tableHandle.getSchemaName(), + tableHandle.getTableName(), + desiredColumns, + constraint.getSummary()); + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(layoutHandle), constraint.getSummary())); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + return new ConnectorTableLayout(handle); + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) + { + ThriftTableHandle handle = ((ThriftTableHandle) tableHandle); + return getTableMetadata(new SchemaTableName(handle.getSchemaName(), handle.getTableName())); + } + + @Override + public List listTables(ConnectorSession session, String schemaNameOrNull) + { + return clientProvider.runOnAnyHost(client -> client.listTables(new PrestoThriftNullableSchemaName(schemaNameOrNull))).stream() + .map(PrestoThriftSchemaTableName::toSchemaTableName) + .collect(toImmutableList()); + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return getTableMetadata(session, tableHandle).getColumns().stream().collect(toImmutableMap(ColumnMetadata::getName, ThriftColumnHandle::new)); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + return ((ThriftColumnHandle) columnHandle).toColumnMetadata(); + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + return listTables(session, prefix.getSchemaName()).stream().collect(toImmutableMap(identity(), schemaTableName -> getTableMetadata(schemaTableName).getColumns())); + } + + private ConnectorTableMetadata getTableMetadata(SchemaTableName schemaTableName) + { + Optional table = tableCache.getUnchecked(schemaTableName); + if (!table.isPresent()) { + throw new TableNotFoundException(schemaTableName); + } + else { + return table.get(); + } + } + + // this method makes actual thrift request and should be called only by cache load method + private Optional getTableMetadataInternal(SchemaTableName schemaTableName) + { + requireNonNull(schemaTableName, "schemaTableName is null"); + return clientProvider.runOnAnyHost(client -> { + PrestoThriftNullableTableMetadata thriftTableMetadata = client.getTableMetadata(fromSchemaTableName(schemaTableName)); + if (thriftTableMetadata.getTableMetadata() == null) { + return Optional.empty(); + } + else { + ConnectorTableMetadata tableMetadata = thriftTableMetadata.getTableMetadata().toConnectorTableMetadata(typeManager); + if (!Objects.equals(schemaTableName, tableMetadata.getTable())) { + throw new PrestoException(THRIFT_SERVICE_INVALID_RESPONSE, "Requested and actual table names are different"); + } + return Optional.of(tableMetadata); + } + }); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftModule.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftModule.java new file mode 100644 index 0000000000000..c87dc8a7c54d2 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftModule.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.annotations.ForMetadataRefresh; +import com.facebook.presto.connector.thrift.annotations.NonRetrying; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.clientproviders.DefaultPrestoThriftServiceProvider; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.connector.thrift.clientproviders.RetryingPrestoThriftServiceProvider; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Scopes; + +import javax.inject.Singleton; + +import java.util.concurrent.Executor; + +import static com.facebook.swift.service.guice.ThriftClientBinder.thriftClientBinder; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static java.util.concurrent.Executors.newFixedThreadPool; + +public class ThriftModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ThriftConnector.class).in(Scopes.SINGLETON); + thriftClientBinder(binder).bindThriftClient(PrestoThriftService.class); + binder.bind(ThriftMetadata.class).in(Scopes.SINGLETON); + binder.bind(ThriftSplitManager.class).in(Scopes.SINGLETON); + binder.bind(ThriftPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(PrestoThriftServiceProvider.class).to(RetryingPrestoThriftServiceProvider.class).in(Scopes.SINGLETON); + binder.bind(PrestoThriftServiceProvider.class).annotatedWith(NonRetrying.class).to(DefaultPrestoThriftServiceProvider.class).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(ThriftConnectorConfig.class); + binder.bind(ThriftSessionProperties.class).in(Scopes.SINGLETON); + } + + @Provides + @Singleton + @ForMetadataRefresh + public Executor createMetadataRefreshExecutor(ThriftConnectorConfig config) + { + return newFixedThreadPool(config.getMetadataRefreshThreads(), daemonThreadsNamed("metadata-refresh-%s")); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSource.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSource.java new file mode 100644 index 0000000000000..77f9287152684 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSource.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftPageResult; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.type.Type; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.concurrent.MoreFutures.toCompletableFuture; +import static java.util.Objects.requireNonNull; + +public class ThriftPageSource + implements ConnectorPageSource +{ + private final PrestoThriftId splitId; + private final PrestoThriftService client; + private final List columnNames; + private final List columnTypes; + private final long maxBytesPerResponse; + private final AtomicLong readTimeNanos = new AtomicLong(0); + + private PrestoThriftId nextToken; + private boolean firstCall = true; + private CompletableFuture future; + private long completedBytes; + + public ThriftPageSource( + PrestoThriftServiceProvider clientProvider, + ThriftConnectorSplit split, + List columns, + long maxBytesPerResponse) + { + // init columns + requireNonNull(columns, "columns is null"); + ImmutableList.Builder columnNames = new ImmutableList.Builder<>(); + ImmutableList.Builder columnTypes = new ImmutableList.Builder<>(); + for (ColumnHandle columnHandle : columns) { + ThriftColumnHandle thriftColumnHandle = (ThriftColumnHandle) columnHandle; + columnNames.add(thriftColumnHandle.getColumnName()); + columnTypes.add(thriftColumnHandle.getColumnType()); + } + this.columnNames = columnNames.build(); + this.columnTypes = columnTypes.build(); + + // this parameter is read from config, so it should be checked by config validation + // however, here it's a raw constructor parameter, so adding this safety check + checkArgument(maxBytesPerResponse > 0, "maxBytesPerResponse is zero or negative"); + this.maxBytesPerResponse = maxBytesPerResponse; + + // init split + requireNonNull(split, "split is null"); + this.splitId = split.getSplitId(); + + // init client + requireNonNull(clientProvider, "clientProvider is null"); + if (split.getAddresses().isEmpty()) { + this.client = clientProvider.anyHostClient(); + } + else { + this.client = clientProvider.selectedHostClient(split.getAddresses()); + } + } + + @Override + public long getTotalBytes() + { + return 0; + } + + @Override + public long getCompletedBytes() + { + return completedBytes; + } + + @Override + public long getReadTimeNanos() + { + return readTimeNanos.get(); + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public boolean isFinished() + { + return !firstCall && !canGetMoreData(nextToken); + } + + @Override + public Page getNextPage() + { + if (future == null) { + // no data request in progress + if (firstCall || canGetMoreData(nextToken)) { + // no data in the current batch, but can request more; will send a request + future = sendDataRequestInternal(); + } + return null; + } + + if (!future.isDone()) { + // data request is in progress + return null; + } + + // response for data request is ready + Page result = processBatch(getFutureValue(future)); + + // immediately try sending a new request + if (canGetMoreData(nextToken)) { + future = sendDataRequestInternal(); + } + else { + future = null; + } + + return result; + } + + private static boolean canGetMoreData(PrestoThriftId nextToken) + { + return nextToken != null; + } + + private CompletableFuture sendDataRequestInternal() + { + long start = System.nanoTime(); + ListenableFuture rowsBatchFuture = client.getRows( + splitId, + columnNames, + maxBytesPerResponse, + new PrestoThriftNullableToken(nextToken)); + rowsBatchFuture.addListener(() -> readTimeNanos.addAndGet(System.nanoTime() - start), directExecutor()); + return toCompletableFuture(nonCancellationPropagating(rowsBatchFuture)); + } + + private Page processBatch(PrestoThriftPageResult rowsBatch) + { + firstCall = false; + nextToken = rowsBatch.getNextToken(); + Page page = rowsBatch.toPage(columnTypes); + if (page != null) { + completedBytes += page.getSizeInBytes(); + } + return page; + } + + @Override + public CompletableFuture isBlocked() + { + return future == null || future.isDone() ? NOT_BLOCKED : future; + } + + @Override + public void close() + throws IOException + { + if (future != null) { + future.cancel(true); + } + client.close(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSourceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSourceProvider.java new file mode 100644 index 0000000000000..4a63442cc58ef --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPageSourceProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +import javax.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ThriftPageSourceProvider + implements ConnectorPageSourceProvider +{ + private final PrestoThriftServiceProvider clientProvider; + private final long maxBytesPerResponse; + + @Inject + public ThriftPageSourceProvider(PrestoThriftServiceProvider clientProvider, ThriftConnectorConfig config) + { + this.clientProvider = requireNonNull(clientProvider, "clientProvider is null"); + this.maxBytesPerResponse = requireNonNull(config, "config is null").getMaxResponseSize().toBytes(); + } + + @Override + public ConnectorPageSource createPageSource( + ConnectorTransactionHandle transactionHandle, + ConnectorSession session, + ConnectorSplit split, + List columns) + { + return new ThriftPageSource(clientProvider, (ThriftConnectorSplit) split, columns, maxBytesPerResponse); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPlugin.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPlugin.java new file mode 100644 index 0000000000000..529a912dd9ab6 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPlugin.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; + +import java.util.List; +import java.util.ServiceLoader; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class ThriftPlugin + implements Plugin +{ + private final String name; + private final Module locationModule; + + public ThriftPlugin() + { + this(getPluginInfo()); + } + + private ThriftPlugin(ThriftPluginInfo info) + { + this(info.getName(), info.getLocationModule()); + } + + public ThriftPlugin(String name, Module locationModule) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.locationModule = requireNonNull(locationModule, "locationModule is null"); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ThriftConnectorFactory(name, locationModule)); + } + + private static ThriftPluginInfo getPluginInfo() + { + ClassLoader classLoader = ThriftPlugin.class.getClassLoader(); + ServiceLoader loader = ServiceLoader.load(ThriftPluginInfo.class, classLoader); + List list = ImmutableList.copyOf(loader); + return list.isEmpty() ? new ThriftPluginInfo() : getOnlyElement(list); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPluginInfo.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPluginInfo.java new file mode 100644 index 0000000000000..e8e193661c31b --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftPluginInfo.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.location.StaticLocationModule; +import com.google.inject.Module; + +public class ThriftPluginInfo +{ + public String getName() + { + return "presto-thrift"; + } + + public Module getLocationModule() + { + return new StaticLocationModule(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSessionProperties.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSessionProperties.java new file mode 100644 index 0000000000000..f3c036b01bd21 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSessionProperties.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.session.PropertyMetadata; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; + +import java.util.List; + +/** + * Internal session properties are those defined by the connector itself. + * These properties control certain aspects of connector's work. + */ +public final class ThriftSessionProperties +{ + private final List> sessionProperties; + + @Inject + public ThriftSessionProperties(ThriftConnectorConfig config) + { + sessionProperties = ImmutableList.of(); + } + + public List> getSessionProperties() + { + return sessionProperties; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSplitManager.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSplitManager.java new file mode 100644 index 0000000000000..a36f33de8e475 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftSplitManager.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.connector.thrift.api.PrestoThriftDomain; +import com.facebook.presto.connector.thrift.api.PrestoThriftHostAddress; +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableColumnSet; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplit; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplitBatch; +import com.facebook.presto.connector.thrift.api.PrestoThriftTupleDomain; +import com.facebook.presto.connector.thrift.clientproviders.PrestoThriftServiceProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +import javax.annotation.concurrent.NotThreadSafe; +import javax.inject.Inject; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftDomain.fromDomain; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.concurrent.MoreFutures.toCompletableFuture; +import static java.util.Objects.requireNonNull; + +public class ThriftSplitManager + implements ConnectorSplitManager +{ + private final PrestoThriftServiceProvider clientProvider; + + @Inject + public ThriftSplitManager(PrestoThriftServiceProvider clientProvider) + { + this.clientProvider = requireNonNull(clientProvider, "clientProvider is null"); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout) + { + ThriftTableLayoutHandle layoutHandle = (ThriftTableLayoutHandle) layout; + return new ThriftSplitSource( + clientProvider.anyHostClient(), + new PrestoThriftSchemaTableName(layoutHandle.getSchemaName(), layoutHandle.getTableName()), + layoutHandle.getColumns().map(ThriftSplitManager::columnNames), + tupleDomainToThriftTupleDomain(layoutHandle.getConstraint())); + } + + private static Set columnNames(Set columns) + { + return columns.stream() + .map(ThriftColumnHandle.class::cast) + .map(ThriftColumnHandle::getColumnName) + .collect(toImmutableSet()); + } + + private static PrestoThriftTupleDomain tupleDomainToThriftTupleDomain(TupleDomain tupleDomain) + { + if (!tupleDomain.getDomains().isPresent()) { + return new PrestoThriftTupleDomain(null); + } + Map thriftDomains = tupleDomain.getDomains().get() + .entrySet().stream() + .collect(toImmutableMap( + entry -> ((ThriftColumnHandle) entry.getKey()).getColumnName(), + entry -> fromDomain(entry.getValue()))); + return new PrestoThriftTupleDomain(thriftDomains); + } + + @NotThreadSafe + private static class ThriftSplitSource + implements ConnectorSplitSource + { + private final PrestoThriftService client; + private final PrestoThriftSchemaTableName schemaTableName; + private final Optional> columnNames; + private final PrestoThriftTupleDomain constraint; + + // the code assumes getNextBatch is called by a single thread + + private final AtomicBoolean hasMoreData; + private final AtomicReference nextToken; + private final AtomicReference> future; + + public ThriftSplitSource( + PrestoThriftService client, + PrestoThriftSchemaTableName schemaTableName, + Optional> columnNames, + PrestoThriftTupleDomain constraint) + { + this.client = requireNonNull(client, "client is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.columnNames = requireNonNull(columnNames, "columnNames is null"); + this.constraint = requireNonNull(constraint, "constraint is null"); + this.nextToken = new AtomicReference<>(null); + this.hasMoreData = new AtomicBoolean(true); + this.future = new AtomicReference<>(null); + } + + /** + * Returns a future with a list of splits. + * This method is assumed to be called in a single-threaded way. + * It can be called by multiple threads, but only if the previous call finished. + */ + @Override + public CompletableFuture> getNextBatch(int maxSize) + { + checkState(future.get() == null || future.get().isDone(), "previous batch not completed"); + checkState(hasMoreData.get(), "this method cannot be invoked when there's no more data"); + PrestoThriftId currentToken = nextToken.get(); + ListenableFuture splitsFuture = client.getSplits( + schemaTableName, + new PrestoThriftNullableColumnSet(columnNames.orElse(null)), + constraint, + maxSize, + new PrestoThriftNullableToken(currentToken)); + ListenableFuture> resultFuture = Futures.transform( + splitsFuture, + batch -> { + requireNonNull(batch, "batch is null"); + List splits = batch.getSplits().stream() + .map(ThriftSplitSource::toConnectorSplit) + .collect(toImmutableList()); + checkState(nextToken.compareAndSet(currentToken, batch.getNextToken())); + checkState(hasMoreData.compareAndSet(true, nextToken.get() != null)); + return splits; + }); + future.set(resultFuture); + return toCompletableFuture(resultFuture); + } + + @Override + public boolean isFinished() + { + return !hasMoreData.get(); + } + + @Override + public void close() + { + Future currentFuture = future.getAndSet(null); + if (currentFuture != null) { + currentFuture.cancel(true); + } + client.close(); + } + + private static ThriftConnectorSplit toConnectorSplit(PrestoThriftSplit thriftSplit) + { + return new ThriftConnectorSplit( + thriftSplit.getSplitId(), + toHostAddressList(thriftSplit.getHosts())); + } + + private static List toHostAddressList(List hosts) + { + return hosts.stream().map(PrestoThriftHostAddress::toHostAddress).collect(toImmutableList()); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableHandle.java new file mode 100644 index 0000000000000..c47db0adff710 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableHandle.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class ThriftTableHandle + implements ConnectorTableHandle +{ + private final String schemaName; + private final String tableName; + + @JsonCreator + public ThriftTableHandle( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + } + + public ThriftTableHandle(SchemaTableName schemaTableName) + { + this(schemaTableName.getSchemaName(), schemaTableName.getTableName()); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ThriftTableHandle other = (ThriftTableHandle) obj; + return Objects.equals(this.schemaName, other.schemaName) && + Objects.equals(this.tableName, other.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaName, tableName); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", getSchemaName()) + .add("tableName", getTableName()) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableLayoutHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableLayoutHandle.java new file mode 100644 index 0000000000000..a1f15667c53cb --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTableLayoutHandle.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.predicate.TupleDomain; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableSet; + +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class ThriftTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final String schemaName; + private final String tableName; + private final Optional> columns; + private final TupleDomain constraint; + + @JsonCreator + public ThriftTableLayoutHandle( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("columns") Optional> columns, + @JsonProperty("constraint") TupleDomain constraint) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.columns = requireNonNull(columns, "columns is null").map(ImmutableSet::copyOf); + this.constraint = requireNonNull(constraint, "constraint is null"); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public Optional> getColumns() + { + return columns; + } + + @JsonProperty + public TupleDomain getConstraint() + { + return constraint; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ThriftTableLayoutHandle other = (ThriftTableLayoutHandle) o; + return schemaName.equals(other.schemaName) + && tableName.equals(other.tableName) + && columns.equals(other.columns) + && constraint.equals(other.constraint); + } + + @Override + public int hashCode() + { + return Objects.hash(schemaName, tableName, columns, constraint); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("schemaName", schemaName) + .add("tableName", tableName) + .add("columns", columns) + .add("constraint", constraint) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTransactionHandle.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTransactionHandle.java new file mode 100644 index 0000000000000..9ab3ef68d8210 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/ThriftTransactionHandle.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum ThriftTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ForCassandra.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/ForMetadataRefresh.java similarity index 87% rename from presto-cassandra/src/main/java/com/facebook/presto/cassandra/ForCassandra.java rename to presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/ForMetadataRefresh.java index 704b2a74c048d..b2fff95db23a6 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ForCassandra.java +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/ForMetadataRefresh.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.cassandra; +package com.facebook.presto.connector.thrift.annotations; import javax.inject.Qualifier; @@ -24,8 +24,8 @@ import static java.lang.annotation.RetentionPolicy.RUNTIME; @Retention(RUNTIME) -@Target({FIELD, PARAMETER, METHOD}) +@Target({PARAMETER, METHOD, FIELD}) @Qualifier -public @interface ForCassandra +public @interface ForMetadataRefresh { } diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/NonRetrying.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/NonRetrying.java new file mode 100644 index 0000000000000..41c908eafccba --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/annotations/NonRetrying.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.annotations; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@BindingAnnotation +@Target(PARAMETER) +@Retention(RUNTIME) +public @interface NonRetrying +{ +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/DefaultPrestoThriftServiceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/DefaultPrestoThriftServiceProvider.java new file mode 100644 index 0000000000000..5170998530e64 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/DefaultPrestoThriftServiceProvider.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.clientproviders; + +import com.facebook.nifty.client.FramedClientConnector; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.location.HostLocationProvider; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.PrestoException; +import com.facebook.swift.service.ThriftClient; +import com.google.common.net.HostAndPort; +import io.airlift.units.Duration; + +import javax.inject.Inject; + +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static com.facebook.presto.connector.thrift.ThriftErrorCode.THRIFT_SERVICE_CONNECTION_ERROR; +import static java.util.Objects.requireNonNull; + +public class DefaultPrestoThriftServiceProvider + implements PrestoThriftServiceProvider +{ + private final ThriftClient thriftClient; + private final HostLocationProvider locationProvider; + private final long thriftConnectTimeoutMs; + + @Inject + public DefaultPrestoThriftServiceProvider(ThriftClient thriftClient, HostLocationProvider locationProvider) + { + this.thriftClient = requireNonNull(thriftClient, "thriftClient is null"); + this.locationProvider = requireNonNull(locationProvider, "locationProvider is null"); + this.thriftConnectTimeoutMs = Duration.valueOf(thriftClient.getConnectTimeout()).toMillis(); + } + + @Override + public PrestoThriftService anyHostClient() + { + return connectTo(locationProvider.getAnyHost()); + } + + @Override + public PrestoThriftService selectedHostClient(List hosts) + { + return connectTo(locationProvider.getAnyOf(hosts)); + } + + private PrestoThriftService connectTo(HostAddress host) + { + try { + return thriftClient.open(new FramedClientConnector(HostAndPort.fromParts(host.getHostText(), host.getPort()))) + .get(thriftConnectTimeoutMs, TimeUnit.MILLISECONDS); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while connecting to thrift host at " + host, e); + } + catch (ExecutionException | TimeoutException e) { + throw new PrestoException(THRIFT_SERVICE_CONNECTION_ERROR, "Cannot connect to thrift host at " + host, e); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/PrestoThriftServiceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/PrestoThriftServiceProvider.java new file mode 100644 index 0000000000000..c6e0c629063f1 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/PrestoThriftServiceProvider.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.clientproviders; + +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.spi.HostAddress; + +import java.util.List; +import java.util.function.Function; + +public interface PrestoThriftServiceProvider +{ + PrestoThriftService anyHostClient(); + + PrestoThriftService selectedHostClient(List hosts); + + default V runOnAnyHost(Function call) + { + try (PrestoThriftService client = anyHostClient()) { + return call.apply(client); + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/RetryingPrestoThriftServiceProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/RetryingPrestoThriftServiceProvider.java new file mode 100644 index 0000000000000..cfc1ca628055c --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/clientproviders/RetryingPrestoThriftServiceProvider.java @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.clientproviders; + +import com.facebook.presto.connector.thrift.annotations.NonRetrying; +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableColumnSet; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableSchemaName; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftPageResult; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.api.PrestoThriftServiceException; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplitBatch; +import com.facebook.presto.connector.thrift.api.PrestoThriftTupleDomain; +import com.facebook.presto.connector.thrift.util.RetryDriver; +import com.facebook.presto.spi.HostAddress; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; +import io.airlift.units.Duration; + +import javax.annotation.concurrent.NotThreadSafe; +import javax.inject.Inject; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +public class RetryingPrestoThriftServiceProvider + implements PrestoThriftServiceProvider +{ + private static final Logger log = Logger.get(RetryingPrestoThriftServiceProvider.class); + private final PrestoThriftServiceProvider original; + private final RetryDriver retry; + + @Inject + public RetryingPrestoThriftServiceProvider(@NonRetrying PrestoThriftServiceProvider original) + { + this.original = requireNonNull(original, "original is null"); + retry = RetryDriver.retry() + .maxAttempts(5) + .stopRetryingWhen(e -> e instanceof PrestoThriftServiceException && !((PrestoThriftServiceException) e).isRetryable()) + .exponentialBackoff( + new Duration(10, TimeUnit.MILLISECONDS), + new Duration(20, TimeUnit.MILLISECONDS), + new Duration(30, TimeUnit.SECONDS), + 1.5); + } + + @Override + public PrestoThriftService anyHostClient() + { + return new RetryingService(original::anyHostClient, retry); + } + + @Override + public PrestoThriftService selectedHostClient(List hosts) + { + return new RetryingService(() -> original.selectedHostClient(hosts), retry); + } + + @NotThreadSafe + private static final class RetryingService + implements PrestoThriftService + { + private final Supplier clientSupplier; + private final RetryDriver retry; + private PrestoThriftService client; + + public RetryingService(Supplier clientSupplier, RetryDriver retry) + { + this.clientSupplier = requireNonNull(clientSupplier, "clientSupplier is null"); + this.retry = retry.onRetry(this::close); + } + + private PrestoThriftService getClient() + { + if (client != null) { + return client; + } + client = clientSupplier.get(); + return client; + } + + @Override + public List listSchemaNames() + { + return retry.run("listSchemaNames", () -> getClient().listSchemaNames()); + } + + @Override + public List listTables(PrestoThriftNullableSchemaName schemaNameOrNull) + { + return retry.run("listTables", () -> getClient().listTables(schemaNameOrNull)); + } + + @Override + public PrestoThriftNullableTableMetadata getTableMetadata(PrestoThriftSchemaTableName schemaTableName) + { + return retry.run("getTableMetadata", () -> getClient().getTableMetadata(schemaTableName)); + } + + @Override + public ListenableFuture getSplits( + PrestoThriftSchemaTableName schemaTableName, + PrestoThriftNullableColumnSet desiredColumns, + PrestoThriftTupleDomain outputConstraint, + int maxSplitCount, + PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException + { + return retry.run("getSplits", () -> getClient().getSplits(schemaTableName, desiredColumns, outputConstraint, maxSplitCount, nextToken)); + } + + @Override + public ListenableFuture getRows(PrestoThriftId splitId, List columns, long maxBytes, PrestoThriftNullableToken nextToken) + { + return retry.run("getRows", () -> getClient().getRows(splitId, columns, maxBytes, nextToken)); + } + + @Override + public void close() + { + if (client == null) { + return; + } + try { + client.close(); + } + catch (Exception e) { + log.warn("Error closing client", e); + } + client = null; + } + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostList.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostList.java new file mode 100644 index 0000000000000..95c3d09f258c5 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostList.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; +import com.google.common.base.Joiner; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Arrays.asList; +import static java.util.Objects.requireNonNull; + +public final class HostList +{ + private final List hosts; + + private HostList(List hosts) + { + this.hosts = ImmutableList.copyOf(requireNonNull(hosts, "hosts is null")); + } + + // needed for automatic config parsing + @SuppressWarnings("unused") + public static HostList fromString(String hosts) + { + return new HostList(Splitter.on(',').trimResults().omitEmptyStrings().splitToList(hosts).stream().map(HostAddress::fromString).collect(toImmutableList())); + } + + public static HostList of(HostAddress... hosts) + { + return new HostList(asList(hosts)); + } + + public static HostList fromList(List hosts) + { + return new HostList(hosts); + } + + public List getHosts() + { + return hosts; + } + + public String stringValue() + { + return Joiner.on(',').join(hosts); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + HostList hostList = (HostList) o; + return hosts.equals(hostList.hosts); + } + + @Override + public int hashCode() + { + return hosts.hashCode(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("hosts", hosts) + .toString(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostLocationProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostLocationProvider.java new file mode 100644 index 0000000000000..143d518d41e48 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/HostLocationProvider.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; + +import java.util.List; + +public interface HostLocationProvider +{ + HostAddress getAnyHost(); + + HostAddress getAnyOf(List hosts); +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationConfig.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationConfig.java new file mode 100644 index 0000000000000..c252f17c7fa69 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationConfig.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import io.airlift.configuration.Config; + +import javax.validation.constraints.NotNull; + +public class StaticLocationConfig +{ + private HostList hosts; + + @NotNull + public HostList getHosts() + { + return hosts; + } + + @Config("static-location.hosts") + public StaticLocationConfig setHosts(HostList hosts) + { + this.hosts = hosts; + return this; + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationModule.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationModule.java new file mode 100644 index 0000000000000..7b012fdf9b040 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationModule.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static io.airlift.configuration.ConfigBinder.configBinder; + +public class StaticLocationModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(StaticLocationConfig.class); + binder.bind(HostLocationProvider.class).to(StaticLocationProvider.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationProvider.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationProvider.java new file mode 100644 index 0000000000000..0346671897712 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/location/StaticLocationProvider.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; + +import javax.inject.Inject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class StaticLocationProvider + implements HostLocationProvider +{ + private final List hosts; + private final AtomicInteger index = new AtomicInteger(0); + + @Inject + public StaticLocationProvider(StaticLocationConfig config) + { + requireNonNull(config, "config is null"); + List hosts = config.getHosts().getHosts(); + checkArgument(!hosts.isEmpty(), "hosts is empty"); + this.hosts = new ArrayList<>(hosts); + Collections.shuffle(this.hosts); + } + + /** + * Provides the next host from a configured list of hosts in a round-robin fashion. + */ + @Override + public HostAddress getAnyHost() + { + return hosts.get(index.getAndUpdate(this::next)); + } + + @Override + public HostAddress getAnyOf(List requestedHosts) + { + checkArgument(requestedHosts != null && !requestedHosts.isEmpty(), "requestedHosts is null or empty"); + return requestedHosts.get(ThreadLocalRandom.current().nextInt(requestedHosts.size())); + } + + private int next(int x) + { + return (x + 1) % hosts.size(); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RebindSafeMBeanServer.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RebindSafeMBeanServer.java new file mode 100644 index 0000000000000..f8fec2de99f88 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RebindSafeMBeanServer.java @@ -0,0 +1,335 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.util; + +import io.airlift.log.Logger; + +import javax.annotation.concurrent.ThreadSafe; +import javax.management.Attribute; +import javax.management.AttributeList; +import javax.management.AttributeNotFoundException; +import javax.management.InstanceAlreadyExistsException; +import javax.management.InstanceNotFoundException; +import javax.management.IntrospectionException; +import javax.management.InvalidAttributeValueException; +import javax.management.ListenerNotFoundException; +import javax.management.MBeanException; +import javax.management.MBeanInfo; +import javax.management.MBeanRegistrationException; +import javax.management.MBeanServer; +import javax.management.NotCompliantMBeanException; +import javax.management.NotificationFilter; +import javax.management.NotificationListener; +import javax.management.ObjectInstance; +import javax.management.ObjectName; +import javax.management.OperationsException; +import javax.management.QueryExp; +import javax.management.ReflectionException; +import javax.management.loading.ClassLoaderRepository; + +import java.io.ObjectInputStream; +import java.util.Set; + +// TODO: move this to airlift or jmxutils + +/** + * MBeanServer wrapper that a ignores calls to registerMBean when there is already + * a MBean registered with the specified object name. + */ +@ThreadSafe +public class RebindSafeMBeanServer + implements MBeanServer +{ + private static final Logger log = Logger.get(RebindSafeMBeanServer.class); + + private final MBeanServer mbeanServer; + + public RebindSafeMBeanServer(MBeanServer mbeanServer) + { + this.mbeanServer = mbeanServer; + } + + /** + * Delegates to the wrapped mbean server, but if a mbean is already registered + * with the specified name, the existing instance is returned. + */ + @Override + public ObjectInstance registerMBean(Object object, ObjectName name) + throws MBeanRegistrationException, NotCompliantMBeanException + { + while (true) { + try { + // try to register the mbean + return mbeanServer.registerMBean(object, name); + } + catch (InstanceAlreadyExistsException ignored) { + } + + try { + // a mbean is already installed, try to return the already registered instance + ObjectInstance objectInstance = mbeanServer.getObjectInstance(name); + log.debug("%s already bound to %s", name, objectInstance); + return objectInstance; + } + catch (InstanceNotFoundException ignored) { + // the mbean was removed before we could get the reference + // start the whole process over again + } + } + } + + @Override + public void unregisterMBean(ObjectName name) + throws InstanceNotFoundException, MBeanRegistrationException + { + mbeanServer.unregisterMBean(name); + } + + @Override + public ObjectInstance getObjectInstance(ObjectName name) + throws InstanceNotFoundException + { + return mbeanServer.getObjectInstance(name); + } + + @Override + public Set queryMBeans(ObjectName name, QueryExp query) + { + return mbeanServer.queryMBeans(name, query); + } + + @Override + public Set queryNames(ObjectName name, QueryExp query) + { + return mbeanServer.queryNames(name, query); + } + + @Override + public boolean isRegistered(ObjectName name) + { + return mbeanServer.isRegistered(name); + } + + @Override + public Integer getMBeanCount() + { + return mbeanServer.getMBeanCount(); + } + + @Override + public Object getAttribute(ObjectName name, String attribute) + throws MBeanException, AttributeNotFoundException, InstanceNotFoundException, ReflectionException + { + return mbeanServer.getAttribute(name, attribute); + } + + @Override + public AttributeList getAttributes(ObjectName name, String[] attributes) + throws InstanceNotFoundException, ReflectionException + { + return mbeanServer.getAttributes(name, attributes); + } + + @Override + public void setAttribute(ObjectName name, Attribute attribute) + throws InstanceNotFoundException, AttributeNotFoundException, InvalidAttributeValueException, MBeanException, ReflectionException + { + mbeanServer.setAttribute(name, attribute); + } + + @Override + public AttributeList setAttributes(ObjectName name, AttributeList attributes) + throws InstanceNotFoundException, ReflectionException + { + return mbeanServer.setAttributes(name, attributes); + } + + @Override + public Object invoke(ObjectName name, String operationName, Object[] params, String[] signature) + throws InstanceNotFoundException, MBeanException, ReflectionException + { + return mbeanServer.invoke(name, operationName, params, signature); + } + + @Override + public String getDefaultDomain() + { + return mbeanServer.getDefaultDomain(); + } + + @Override + public String[] getDomains() + { + return mbeanServer.getDomains(); + } + + @Override + public void addNotificationListener(ObjectName name, NotificationListener listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException + { + mbeanServer.addNotificationListener(name, listener, filter, context); + } + + @Override + public void addNotificationListener(ObjectName name, ObjectName listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException + { + mbeanServer.addNotificationListener(name, listener, filter, context); + } + + @Override + public void removeNotificationListener(ObjectName name, ObjectName listener) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener); + } + + @Override + public void removeNotificationListener(ObjectName name, ObjectName listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener, filter, context); + } + + @Override + public void removeNotificationListener(ObjectName name, NotificationListener listener) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener); + } + + @Override + public void removeNotificationListener(ObjectName name, NotificationListener listener, NotificationFilter filter, Object context) + throws InstanceNotFoundException, ListenerNotFoundException + { + mbeanServer.removeNotificationListener(name, listener, filter, context); + } + + @Override + public MBeanInfo getMBeanInfo(ObjectName name) + throws InstanceNotFoundException, IntrospectionException, ReflectionException + { + return mbeanServer.getMBeanInfo(name); + } + + @Override + public boolean isInstanceOf(ObjectName name, String className) + throws InstanceNotFoundException + { + return mbeanServer.isInstanceOf(name, className); + } + + @Override + public Object instantiate(String className) + throws ReflectionException, MBeanException + { + return mbeanServer.instantiate(className); + } + + @Override + public Object instantiate(String className, ObjectName loaderName) + throws ReflectionException, MBeanException, InstanceNotFoundException + { + return mbeanServer.instantiate(className, loaderName); + } + + @Override + public Object instantiate(String className, Object[] params, String[] signature) + throws ReflectionException, MBeanException + { + return mbeanServer.instantiate(className, params, signature); + } + + @Override + public Object instantiate(String className, ObjectName loaderName, Object[] params, String[] signature) + throws ReflectionException, MBeanException, InstanceNotFoundException + { + return mbeanServer.instantiate(className, loaderName, params, signature); + } + + @SuppressWarnings("deprecation") + @Override + @Deprecated + public ObjectInputStream deserialize(ObjectName name, byte[] data) + throws OperationsException + { + return mbeanServer.deserialize(name, data); + } + + @SuppressWarnings("deprecation") + @Override + @Deprecated + public ObjectInputStream deserialize(String className, byte[] data) + throws OperationsException, ReflectionException + { + return mbeanServer.deserialize(className, data); + } + + @SuppressWarnings("deprecation") + @Override + @Deprecated + public ObjectInputStream deserialize(String className, ObjectName loaderName, byte[] data) + throws OperationsException, ReflectionException + { + return mbeanServer.deserialize(className, loaderName, data); + } + + @Override + public ClassLoader getClassLoaderFor(ObjectName mbeanName) + throws InstanceNotFoundException + { + return mbeanServer.getClassLoaderFor(mbeanName); + } + + @Override + public ClassLoader getClassLoader(ObjectName loaderName) + throws InstanceNotFoundException + { + return mbeanServer.getClassLoader(loaderName); + } + + @Override + public ClassLoaderRepository getClassLoaderRepository() + { + return mbeanServer.getClassLoaderRepository(); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException + { + return mbeanServer.createMBean(className, name); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name, ObjectName loaderName) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException, InstanceNotFoundException + { + return mbeanServer.createMBean(className, name, loaderName); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name, Object[] params, String[] signature) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException + { + return mbeanServer.createMBean(className, name, params, signature); + } + + @Override + public ObjectInstance createMBean(String className, ObjectName name, ObjectName loaderName, Object[] params, String[] signature) + throws ReflectionException, InstanceAlreadyExistsException, MBeanException, NotCompliantMBeanException, InstanceNotFoundException + { + return mbeanServer.createMBean(className, name, loaderName, params, signature); + } +} diff --git a/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RetryDriver.java b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RetryDriver.java new file mode 100644 index 0000000000000..6559403dcfd18 --- /dev/null +++ b/presto-thrift-connector/src/main/java/com/facebook/presto/connector/thrift/util/RetryDriver.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.util; + +import io.airlift.log.Logger; +import io.airlift.units.Duration; + +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.Predicate; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static java.util.Objects.requireNonNull; + +public class RetryDriver +{ + private static final Logger log = Logger.get(RetryDriver.class); + private static final int DEFAULT_RETRY_ATTEMPTS = 10; + private static final Duration DEFAULT_SLEEP_TIME = Duration.valueOf("1s"); + private static final Duration DEFAULT_MAX_RETRY_TIME = Duration.valueOf("30s"); + private static final double DEFAULT_SCALE_FACTOR = 2.0; + + private final int maxAttempts; + private final Duration minSleepTime; + private final Duration maxSleepTime; + private final double scaleFactor; + private final Duration maxRetryTime; + private final Optional retryRunnable; + private final Predicate stopRetrying; + private final Function classifier; + + private RetryDriver( + int maxAttempts, + Duration minSleepTime, + Duration maxSleepTime, + double scaleFactor, + Duration maxRetryTime, + Optional retryRunnable, + Predicate stopRetrying, + Function classifier) + { + this.maxAttempts = maxAttempts; + this.minSleepTime = minSleepTime; + this.maxSleepTime = maxSleepTime; + this.scaleFactor = scaleFactor; + this.maxRetryTime = maxRetryTime; + this.retryRunnable = retryRunnable; + this.stopRetrying = stopRetrying; + this.classifier = classifier; + } + + private RetryDriver() + { + this(DEFAULT_RETRY_ATTEMPTS, + DEFAULT_SLEEP_TIME, + DEFAULT_SLEEP_TIME, + DEFAULT_SCALE_FACTOR, + DEFAULT_MAX_RETRY_TIME, + Optional.empty(), + e -> false, + Function.identity()); + } + + public static RetryDriver retry() + { + return new RetryDriver(); + } + + public final RetryDriver maxAttempts(int maxAttempts) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public final RetryDriver exponentialBackoff(Duration minSleepTime, Duration maxSleepTime, Duration maxRetryTime, double scaleFactor) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public final RetryDriver onRetry(Runnable retryRunnable) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, Optional.ofNullable(retryRunnable), stopRetrying, classifier); + } + + public RetryDriver stopRetryingWhen(Predicate stopRetrying) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public RetryDriver withClassifier(Function classifier) + { + return new RetryDriver(maxAttempts, minSleepTime, maxSleepTime, scaleFactor, maxRetryTime, retryRunnable, stopRetrying, classifier); + } + + public V run(String callableName, Callable callable) + { + requireNonNull(callableName, "callableName is null"); + requireNonNull(callable, "callable is null"); + + long startTime = System.nanoTime(); + int attempt = 0; + while (true) { + attempt++; + + if (attempt > 1) { + retryRunnable.ifPresent(Runnable::run); + } + + try { + return callable.call(); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw propagate(ie); + } + catch (Exception e) { + if (stopRetrying.test(e)) { + throw propagate(e); + } + if (attempt >= maxAttempts || Duration.nanosSince(startTime).compareTo(maxRetryTime) >= 0) { + throw propagate(e); + } + log.warn("Failed on executing %s with attempt %d, will retry. Exception: %s", callableName, attempt, e.getMessage()); + + int delayInMs = (int) Math.min(minSleepTime.toMillis() * Math.pow(scaleFactor, attempt - 1), maxSleepTime.toMillis()); + int jitter = ThreadLocalRandom.current().nextInt(Math.max(1, (int) (delayInMs * 0.1))); + try { + TimeUnit.MILLISECONDS.sleep(delayInMs + jitter); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw propagate(ie); + } + } + } + } + + private RuntimeException propagate(Exception e) + { + Exception classified = classifier.apply(e); + throwIfUnchecked(classified); + throw new RuntimeException(classified); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftConnectorConfig.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftConnectorConfig.java new file mode 100644 index 0000000000000..562a1b3bc9081 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftConnectorConfig.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import io.airlift.units.DataSize; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; + +public class TestThriftConnectorConfig +{ + @Test + public void testDefaults() + { + ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(ThriftConnectorConfig.class) + .setMaxResponseSize(new DataSize(16, MEGABYTE)) + .setMetadataRefreshThreads(1) + ); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("presto-thrift.max-response-size", "2MB") + .put("presto-thrift.metadata-refresh-threads", "10") + .build(); + + ThriftConnectorConfig expected = new ThriftConnectorConfig() + .setMaxResponseSize(new DataSize(2, MEGABYTE)) + .setMetadataRefreshThreads(10); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftPlugin.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftPlugin.java new file mode 100644 index 0000000000000..77c2b410ced78 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/TestThriftPlugin.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.testing.TestingConnectorContext; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.ServiceLoader; + +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.testing.Assertions.assertInstanceOf; +import static org.testng.Assert.assertNotNull; + +public class TestThriftPlugin +{ + @Test + public void testPlugin() + throws Exception + { + ThriftPlugin plugin = loadPlugin(ThriftPlugin.class); + + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + assertInstanceOf(factory, ThriftConnectorFactory.class); + + Map config = ImmutableMap.of("static-location.hosts", "localhost:7777"); + + Connector connector = factory.create("test", config, new TestingConnectorContext()); + assertNotNull(connector); + assertInstanceOf(connector, ThriftConnector.class); + } + + @SuppressWarnings("unchecked") + private static T loadPlugin(Class clazz) + { + for (Plugin plugin : ServiceLoader.load(Plugin.class)) { + if (clazz.isInstance(plugin)) { + return (T) plugin; + } + } + throw new AssertionError("did not find plugin: " + clazz.getName()); + } +} diff --git a/presto-hive-cdh5/src/test/java/com/facebook/presto/hive/TestHiveClient.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftDistributedQueries.java similarity index 52% rename from presto-hive-cdh5/src/test/java/com/facebook/presto/hive/TestHiveClient.java rename to presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftDistributedQueries.java index 9946c031c6f06..90e240898135e 100644 --- a/presto-hive-cdh5/src/test/java/com/facebook/presto/hive/TestHiveClient.java +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftDistributedQueries.java @@ -11,20 +11,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.hive; +package com.facebook.presto.connector.thrift.integration; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Parameters; -import org.testng.annotations.Test; +import com.facebook.presto.tests.AbstractTestQueries; -@Test(groups = "hive") -public class TestHiveClient - extends AbstractTestHiveClient +import static com.facebook.presto.connector.thrift.integration.ThriftQueryRunner.createThriftQueryRunner; + +public class TestThriftDistributedQueries + extends AbstractTestQueries { - @Parameters({"hive.cdh5.metastoreHost", "hive.cdh5.metastorePort", "hive.cdh5.databaseName", "hive.cdh5.timeZone"}) - @BeforeClass - public void initialize(String host, int port, String databaseName, String timeZone) + public TestThriftDistributedQueries() + throws Exception + { + super(() -> createThriftQueryRunner(3, 3)); + } + + @Override + public void testAssignUniqueId() { - setup(host, port, databaseName, timeZone); + // this test can take a long time } } diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftIntegrationSmokeTest.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftIntegrationSmokeTest.java new file mode 100644 index 0000000000000..098e78902c66e --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/TestThriftIntegrationSmokeTest.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.integration; + +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; +import org.testng.annotations.Test; + +import static com.facebook.presto.connector.thrift.integration.ThriftQueryRunner.createThriftQueryRunner; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.tests.QueryAssertions.assertContains; + +public class TestThriftIntegrationSmokeTest + extends AbstractTestIntegrationSmokeTest +{ + public TestThriftIntegrationSmokeTest() + throws Exception + { + super(() -> createThriftQueryRunner(2, 2)); + } + + @Override + @Test + public void testShowSchemas() + throws Exception + { + MaterializedResult actualSchemas = computeActual("SHOW SCHEMAS").toJdbcTypes(); + MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR) + .row("tiny") + .row("sf1"); + assertContains(actualSchemas, resultBuilder.build()); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java new file mode 100644 index 0000000000000..f922deafe1af9 --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.integration; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.thrift.ThriftPlugin; +import com.facebook.presto.connector.thrift.location.HostList; +import com.facebook.presto.connector.thrift.server.ThriftTpchService; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.QualifiedObjectName; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.transaction.TransactionManager; +import com.facebook.swift.codec.ThriftCodecManager; +import com.facebook.swift.service.ThriftServer; +import com.facebook.swift.service.ThriftServiceProcessor; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.log.Logger; +import io.airlift.testing.Closeables; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.Lock; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public final class ThriftQueryRunner +{ + private ThriftQueryRunner() {} + + public static QueryRunner createThriftQueryRunner(int thriftServers, int workers) + throws Exception + { + List servers = null; + DistributedQueryRunner runner = null; + try { + servers = startThriftServers(thriftServers); + runner = createThriftQueryRunnerInternal(servers, workers); + return new ThriftQueryRunnerWithServers(runner, servers); + } + catch (Throwable t) { + Closeables.closeQuietly(runner); + // runner might be null, so closing servers explicitly + if (servers != null) { + for (ThriftServer server : servers) { + Closeables.closeQuietly(server); + } + } + throw t; + } + } + + public static void main(String[] args) + throws Exception + { + ThriftQueryRunnerWithServers queryRunner = (ThriftQueryRunnerWithServers) createThriftQueryRunner(3, 3); + Thread.sleep(10); + Logger log = Logger.get(ThriftQueryRunner.class); + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } + + private static List startThriftServers(int thriftServers) + { + List servers = new ArrayList<>(thriftServers); + for (int i = 0; i < thriftServers; i++) { + ThriftServiceProcessor processor = new ThriftServiceProcessor(new ThriftCodecManager(), ImmutableList.of(), new ThriftTpchService()); + servers.add(new ThriftServer(processor).start()); + } + return servers; + } + + private static DistributedQueryRunner createThriftQueryRunnerInternal(List servers, int workers) + throws Exception + { + List addresses = servers.stream() + .map(server -> HostAddress.fromParts("localhost", server.getPort())) + .collect(toImmutableList()); + HostList hosts = HostList.fromList(addresses); + + Session defaultSession = testSessionBuilder() + .setCatalog("thrift") + .setSchema("tiny") + .build(); + DistributedQueryRunner queryRunner = new DistributedQueryRunner(defaultSession, workers); + queryRunner.installPlugin(new ThriftPlugin()); + Map connectorProperties = ImmutableMap.of( + "static-location.hosts", hosts.stringValue(), + "PrestoThriftService.thrift.client.connect-timeout", "30s" + ); + queryRunner.createCatalog("thrift", "presto-thrift", connectorProperties); + return queryRunner; + } + + /** + * Wraps QueryRunner and a list of ThriftServers to clean them up together. + */ + private static class ThriftQueryRunnerWithServers + implements QueryRunner + { + private final DistributedQueryRunner source; + private final List thriftServers; + + private ThriftQueryRunnerWithServers(DistributedQueryRunner source, List thriftServers) + { + this.source = requireNonNull(source, "source is null"); + this.thriftServers = ImmutableList.copyOf(requireNonNull(thriftServers, "thriftServers is null")); + } + + public TestingPrestoServer getCoordinator() + { + return source.getCoordinator(); + } + + @Override + public void close() + { + Closeables.closeQuietly(source); + for (ThriftServer server : thriftServers) { + Closeables.closeQuietly(server); + } + } + + @Override + public int getNodeCount() + { + return source.getNodeCount(); + } + + @Override + public Session getDefaultSession() + { + return source.getDefaultSession(); + } + + @Override + public TransactionManager getTransactionManager() + { + return source.getTransactionManager(); + } + + @Override + public Metadata getMetadata() + { + return source.getMetadata(); + } + + @Override + public CostCalculator getCostCalculator() + { + return source.getCostCalculator(); + } + + @Override + public TestingAccessControlManager getAccessControl() + { + return source.getAccessControl(); + } + + @Override + public MaterializedResult execute(String sql) + { + return source.execute(sql); + } + + @Override + public MaterializedResult execute(Session session, String sql) + { + return source.execute(session, sql); + } + + @Override + public List listTables(Session session, String catalog, String schema) + { + return source.listTables(session, catalog, schema); + } + + @Override + public boolean tableExists(Session session, String table) + { + return source.tableExists(session, table); + } + + @Override + public void installPlugin(Plugin plugin) + { + source.installPlugin(plugin); + } + + @Override + public void createCatalog(String catalogName, String connectorName, Map properties) + { + source.createCatalog(catalogName, connectorName, properties); + } + + @Override + public Lock getExclusiveLock() + { + return source.getExclusiveLock(); + } + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationConfig.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationConfig.java new file mode 100644 index 0000000000000..1d7968740369c --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationConfig.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; +import com.google.common.collect.ImmutableMap; +import io.airlift.configuration.testing.ConfigAssertions; +import org.testng.annotations.Test; + +import java.util.Map; + +public class TestStaticLocationConfig +{ + @Test + public void testDefaults() + { + ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(StaticLocationConfig.class) + .setHosts(null)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("static-location.hosts", "localhost:7777,localhost:7779") + .build(); + + StaticLocationConfig expected = new StaticLocationConfig() + .setHosts(HostList.of( + HostAddress.fromParts("localhost", 7777), + HostAddress.fromParts("localhost", 7779))); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationProvider.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationProvider.java new file mode 100644 index 0000000000000..25e9f8b21da7f --- /dev/null +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/location/TestStaticLocationProvider.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.location; + +import com.facebook.presto.spi.HostAddress; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.testng.Assert.assertEquals; + +public class TestStaticLocationProvider +{ + @Test + public void testGetAnyHostRoundRobin() + throws Exception + { + List expected = ImmutableList.of( + HostAddress.fromParts("localhost1", 11111), + HostAddress.fromParts("localhost2", 22222), + HostAddress.fromParts("localhost3", 33333)); + HostLocationProvider provider = new StaticLocationProvider(new StaticLocationConfig().setHosts(HostList.fromList(expected))); + List actual = new ArrayList<>(expected.size()); + for (int i = 0; i < expected.size(); i++) { + actual.add(provider.getAnyHost()); + } + assertEquals(ImmutableSet.copyOf(actual), ImmutableSet.copyOf(expected)); + } +} diff --git a/presto-thrift-testing-server/.build-airlift b/presto-thrift-testing-server/.build-airlift new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/presto-thrift-testing-server/README.txt b/presto-thrift-testing-server/README.txt new file mode 100644 index 0000000000000..cf72ed369edd2 --- /dev/null +++ b/presto-thrift-testing-server/README.txt @@ -0,0 +1 @@ +Thrift server implementing Thrift Connector API using TPCH data. diff --git a/presto-thrift-testing-server/etc/config.properties b/presto-thrift-testing-server/etc/config.properties new file mode 100644 index 0000000000000..afd3bbe56f6bf --- /dev/null +++ b/presto-thrift-testing-server/etc/config.properties @@ -0,0 +1,2 @@ +thrift.port=7779 +thrift.max-frame-size=64MB diff --git a/presto-thrift-testing-server/etc/log.properties b/presto-thrift-testing-server/etc/log.properties new file mode 100644 index 0000000000000..290ff616938c3 --- /dev/null +++ b/presto-thrift-testing-server/etc/log.properties @@ -0,0 +1 @@ +com.facebook.presto=DEBUG diff --git a/presto-thrift-testing-server/pom.xml b/presto-thrift-testing-server/pom.xml new file mode 100644 index 0000000000000..1ae4143a47764 --- /dev/null +++ b/presto-thrift-testing-server/pom.xml @@ -0,0 +1,96 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.180-SNAPSHOT + + + presto-thrift-testing-server + presto-thrift-testing-server + Presto - Thrift Testing Server + + + ${project.parent.basedir} + com.facebook.presto.connector.thrift.server.ThriftTpchServer + + + + + com.facebook.presto + presto-thrift-connector-api + + + + com.google.guava + guava + + + + com.google.code.findbugs + annotations + + + + com.facebook.swift + swift-codec + + + + com.facebook.swift + swift-service + + + + com.facebook.presto + presto-tpch + + + + io.airlift.tpch + tpch + + + + io.airlift + log + + + + io.airlift + bootstrap + + + + com.google.inject + guice + + + + javax.annotation + javax.annotation-api + + + + io.airlift + concurrent + + + + com.facebook.presto + presto-spi + + + + com.fasterxml.jackson.core + jackson-annotations + + + + io.airlift + json + + + diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/SplitInfo.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/SplitInfo.java new file mode 100644 index 0000000000000..9b295035b88af --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/SplitInfo.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public final class SplitInfo +{ + private final String schemaName; + private final String tableName; + private final int partNumber; + private final int totalParts; + + @JsonCreator + public SplitInfo( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("partNumber") int partNumber, + @JsonProperty("totalParts") int totalParts) + { + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); + this.partNumber = partNumber; + this.totalParts = totalParts; + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public int getPartNumber() + { + return partNumber; + } + + @JsonProperty + public int getTotalParts() + { + return totalParts; + } +} diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServer.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServer.java new file mode 100644 index 0000000000000..c46cae0637a89 --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServer.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.facebook.swift.codec.guice.ThriftCodecModule; +import com.facebook.swift.service.guice.ThriftClientModule; +import com.facebook.swift.service.guice.ThriftServerModule; +import com.facebook.swift.service.guice.ThriftServerStatsModule; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; +import io.airlift.bootstrap.Bootstrap; +import io.airlift.log.Logger; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public final class ThriftTpchServer +{ + private ThriftTpchServer() + { + } + + public static void start(List extraModules) + throws Exception + { + Bootstrap app = new Bootstrap( + ImmutableList.builder() + .add(new ThriftCodecModule()) + .add(new ThriftClientModule()) + .add(new ThriftServerModule()) + .add(new ThriftServerStatsModule()) + .add(new ThriftTpchServerModule()) + .addAll(requireNonNull(extraModules, "extraModules is null")) + .build()); + app.strictConfig().initialize(); + } + + public static void main(String[] args) + { + try { + ThriftTpchServer.start(ImmutableList.of()); + } + catch (Throwable t) { + Logger.get(ThriftTpchServer.class).error(t); + System.exit(1); + } + } +} diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServerModule.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServerModule.java new file mode 100644 index 0000000000000..9e0d292b3ab73 --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchServerModule.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.swift.service.guice.ThriftServiceExporter.thriftServerBinder; + +public class ThriftTpchServerModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ThriftTpchService.class).in(Scopes.SINGLETON); + thriftServerBinder(binder).exportThriftService(ThriftTpchService.class); + } +} diff --git a/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java new file mode 100644 index 0000000000000..dcc40d2eb059f --- /dev/null +++ b/presto-thrift-testing-server/src/main/java/com/facebook/presto/connector/thrift/server/ThriftTpchService.java @@ -0,0 +1,275 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.thrift.server; + +import com.facebook.presto.connector.thrift.api.PrestoThriftBlock; +import com.facebook.presto.connector.thrift.api.PrestoThriftColumnMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftId; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableColumnSet; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableSchemaName; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftNullableToken; +import com.facebook.presto.connector.thrift.api.PrestoThriftPageResult; +import com.facebook.presto.connector.thrift.api.PrestoThriftSchemaTableName; +import com.facebook.presto.connector.thrift.api.PrestoThriftService; +import com.facebook.presto.connector.thrift.api.PrestoThriftServiceException; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplit; +import com.facebook.presto.connector.thrift.api.PrestoThriftSplitBatch; +import com.facebook.presto.connector.thrift.api.PrestoThriftTableMetadata; +import com.facebook.presto.connector.thrift.api.PrestoThriftTupleDomain; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.Page; +import com.facebook.presto.spi.RecordPageSource; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.tpch.TpchMetadata; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import io.airlift.json.JsonCodec; +import io.airlift.tpch.TpchColumn; +import io.airlift.tpch.TpchColumnType; +import io.airlift.tpch.TpchEntity; +import io.airlift.tpch.TpchTable; + +import javax.annotation.Nullable; +import javax.annotation.PreDestroy; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.connector.thrift.api.PrestoThriftBlock.fromBlock; +import static com.facebook.presto.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; +import static com.facebook.presto.tpch.TpchMetadata.getPrestoType; +import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; +import static io.airlift.concurrent.Threads.threadsNamed; +import static io.airlift.json.JsonCodec.jsonCodec; +import static java.lang.Math.min; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.stream.Collectors.toList; + +public class ThriftTpchService + implements PrestoThriftService +{ + private static final int DEFAULT_NUMBER_OF_SPLITS = 3; + private static final List SCHEMAS = ImmutableList.of("tiny", "sf1"); + private static final JsonCodec SPLIT_INFO_CODEC = jsonCodec(SplitInfo.class); + + private final ListeningExecutorService splitsExecutor = + listeningDecorator(newCachedThreadPool(threadsNamed("splits-generator-%s"))); + private final ListeningExecutorService dataExecutor = + listeningDecorator(newCachedThreadPool(threadsNamed("data-generator-%s"))); + + @Override + public List listSchemaNames() + { + return SCHEMAS; + } + + @Override + public List listTables(PrestoThriftNullableSchemaName schemaNameOrNull) + { + List tables = new ArrayList<>(); + for (String schemaName : getSchemaNames(schemaNameOrNull.getSchemaName())) { + for (TpchTable tpchTable : TpchTable.getTables()) { + tables.add(new PrestoThriftSchemaTableName(schemaName, tpchTable.getTableName())); + } + } + return tables; + } + + private static List getSchemaNames(String schemaNameOrNull) + { + if (schemaNameOrNull == null) { + return SCHEMAS; + } + else if (SCHEMAS.contains(schemaNameOrNull)) { + return ImmutableList.of(schemaNameOrNull); + } + else { + return ImmutableList.of(); + } + } + + @Override + public PrestoThriftNullableTableMetadata getTableMetadata(PrestoThriftSchemaTableName schemaTableName) + { + String schemaName = schemaTableName.getSchemaName(); + String tableName = schemaTableName.getTableName(); + if (!SCHEMAS.contains(schemaName) || TpchTable.getTables().stream().noneMatch(table -> table.getTableName().equals(tableName))) { + return new PrestoThriftNullableTableMetadata(null); + } + TpchTable tpchTable = TpchTable.getTable(schemaTableName.getTableName()); + List columns = new ArrayList<>(); + for (TpchColumn column : tpchTable.getColumns()) { + columns.add(new PrestoThriftColumnMetadata(column.getSimplifiedColumnName(), getTypeString(column.getType()), null, false)); + } + return new PrestoThriftNullableTableMetadata(new PrestoThriftTableMetadata(schemaTableName, columns, null)); + } + + @Override + public ListenableFuture getSplits( + PrestoThriftSchemaTableName schemaTableName, + PrestoThriftNullableColumnSet desiredColumns, + PrestoThriftTupleDomain outputConstraint, + int maxSplitCount, + PrestoThriftNullableToken nextToken) + throws PrestoThriftServiceException + { + return splitsExecutor.submit(() -> getSplitsInternal(schemaTableName, maxSplitCount, nextToken.getToken())); + } + + private static PrestoThriftSplitBatch getSplitsInternal( + PrestoThriftSchemaTableName schemaTableName, + int maxSplitCount, + @Nullable PrestoThriftId nextToken) + { + int totalParts = DEFAULT_NUMBER_OF_SPLITS; + // last sent part + int partNumber = nextToken == null ? 0 : Ints.fromByteArray(nextToken.getId()); + int numberOfSplits = min(maxSplitCount, totalParts - partNumber); + + List splits = new ArrayList<>(numberOfSplits); + for (int i = 0; i < numberOfSplits; i++) { + SplitInfo splitInfo = new SplitInfo( + schemaTableName.getSchemaName(), + schemaTableName.getTableName(), + partNumber + 1, + totalParts); + splits.add(new PrestoThriftSplit(new PrestoThriftId(SPLIT_INFO_CODEC.toJsonBytes(splitInfo)), ImmutableList.of())); + partNumber++; + } + PrestoThriftId newNextToken = partNumber < totalParts ? new PrestoThriftId(Ints.toByteArray(partNumber)) : null; + return new PrestoThriftSplitBatch(splits, newNextToken); + } + + @Override + public ListenableFuture getRows( + PrestoThriftId splitId, + List columns, + long maxBytes, + PrestoThriftNullableToken nextToken) + { + return dataExecutor.submit(() -> getRowsInternal(splitId, columns, maxBytes, nextToken.getToken())); + } + + @PreDestroy + @Override + public void close() + { + splitsExecutor.shutdownNow(); + dataExecutor.shutdownNow(); + } + + private static PrestoThriftPageResult getRowsInternal(PrestoThriftId splitId, List columnNames, long maxBytes, @Nullable PrestoThriftId nextToken) + { + checkArgument(maxBytes >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES, "requested maxBytes is too small"); + SplitInfo splitInfo = SPLIT_INFO_CODEC.fromJson(splitId.getId()); + ConnectorPageSource pageSource = createPageSource(splitInfo, columnNames); + + // very inefficient implementation as it needs to re-generate all previous results to get the next page + int skipPages = nextToken != null ? Ints.fromByteArray(nextToken.getId()) : 0; + skipPages(pageSource, skipPages); + + Page page = null; + while (!pageSource.isFinished() && page == null) { + page = pageSource.getNextPage(); + skipPages++; + } + PrestoThriftId newNextToken = pageSource.isFinished() ? null : new PrestoThriftId(Ints.toByteArray(skipPages)); + + return toThriftPage(page, types(splitInfo.getTableName(), columnNames), newNextToken); + } + + private static PrestoThriftPageResult toThriftPage(Page page, List columnTypes, @Nullable PrestoThriftId nextToken) + { + if (page == null) { + checkState(nextToken == null, "there must be no more data when page is null"); + return new PrestoThriftPageResult(ImmutableList.of(), 0, null); + } + checkState(page.getChannelCount() == columnTypes.size(), "number of columns in a page doesn't match the one in requested types"); + int numberOfColumns = columnTypes.size(); + List columnBlocks = new ArrayList<>(numberOfColumns); + for (int i = 0; i < numberOfColumns; i++) { + columnBlocks.add(fromBlock(page.getBlock(i), columnTypes.get(i))); + } + return new PrestoThriftPageResult(columnBlocks, page.getPositionCount(), nextToken); + } + + private static void skipPages(ConnectorPageSource pageSource, int skipPages) + { + for (int i = 0; i < skipPages; i++) { + checkState(!pageSource.isFinished(), "pageSource is unexpectedly finished"); + pageSource.getNextPage(); + } + } + + private static ConnectorPageSource createPageSource(SplitInfo splitInfo, List columnNames) + { + switch (splitInfo.getTableName()) { + case "orders": + return createPageSource(TpchTable.ORDERS, columnNames, splitInfo); + case "customer": + return createPageSource(TpchTable.CUSTOMER, columnNames, splitInfo); + case "lineitem": + return createPageSource(TpchTable.LINE_ITEM, columnNames, splitInfo); + case "nation": + return createPageSource(TpchTable.NATION, columnNames, splitInfo); + case "region": + return createPageSource(TpchTable.REGION, columnNames, splitInfo); + case "part": + return createPageSource(TpchTable.PART, columnNames, splitInfo); + default: + throw new IllegalArgumentException("Table not setup: " + splitInfo.getTableName()); + } + } + + private static ConnectorPageSource createPageSource(TpchTable table, List columnNames, SplitInfo splitInfo) + { + List> columns = columnNames.stream().map(table::getColumn).collect(toList()); + return new RecordPageSource(createTpchRecordSet( + table, + columns, + schemaNameToScaleFactor(splitInfo.getSchemaName()), + splitInfo.getPartNumber(), + splitInfo.getTotalParts(), + Optional.empty())); + } + + private static List types(String tableName, List columnNames) + { + TpchTable table = TpchTable.getTable(tableName); + return columnNames.stream().map(name -> getPrestoType(table.getColumn(name).getType())).collect(toList()); + } + + private static double schemaNameToScaleFactor(String schemaName) + { + switch (schemaName) { + case "tiny": + return 0.01; + case "sf1": + return 1.0; + } + throw new IllegalArgumentException("Schema is not setup: " + schemaName); + } + + private static String getTypeString(TpchColumnType tpchType) + { + return TpchMetadata.getPrestoType(tpchType).getTypeSignature().toString(); + } +} diff --git a/presto-tpch/pom.xml b/presto-tpch/pom.xml index 0a45cf7f28925..baa0182b7c951 100644 --- a/presto-tpch/pom.xml +++ b/presto-tpch/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-tpch diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java index e520c1bce2b67..a49cd657d8d91 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchConnectorFactory.java @@ -32,16 +32,25 @@ public class TpchConnectorFactory implements ConnectorFactory { + public static final boolean DEFAULT_PREDICATE_PUSHDOWN_ENABLED = false; + private final int defaultSplitsPerNode; + private final boolean defaultPredicatePushdownEnabled; public TpchConnectorFactory() { - this(Runtime.getRuntime().availableProcessors()); + this(Runtime.getRuntime().availableProcessors(), DEFAULT_PREDICATE_PUSHDOWN_ENABLED); } public TpchConnectorFactory(int defaultSplitsPerNode) + { + this(defaultSplitsPerNode, DEFAULT_PREDICATE_PUSHDOWN_ENABLED); + } + + public TpchConnectorFactory(int defaultSplitsPerNode, boolean defaultPredicatePushdownEnabled) { this.defaultSplitsPerNode = defaultSplitsPerNode; + this.defaultPredicatePushdownEnabled = defaultPredicatePushdownEnabled; } @Override @@ -60,6 +69,7 @@ public ConnectorHandleResolver getHandleResolver() public Connector create(String connectorId, Map properties, ConnectorContext context) { int splitsPerNode = getSplitsPerNode(properties); + boolean predicatePushdownEnabled = isPredicatePushdownEnabled(properties); ColumnNaming columnNaming = ColumnNaming.valueOf(properties.getOrDefault("tpch.column-naming", ColumnNaming.SIMPLIFIED.name()).toUpperCase()); NodeManager nodeManager = context.getNodeManager(); @@ -74,7 +84,7 @@ public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel @Override public ConnectorMetadata getMetadata(ConnectorTransactionHandle transaction) { - return new TpchMetadata(connectorId, columnNaming); + return new TpchMetadata(connectorId, predicatePushdownEnabled, columnNaming); } @Override @@ -106,4 +116,9 @@ private int getSplitsPerNode(Map properties) throw new IllegalArgumentException("Invalid property tpch.splits-per-node"); } } + + private boolean isPredicatePushdownEnabled(Map properties) + { + return Boolean.parseBoolean(firstNonNull(properties.get("tpch.predicate-pushdown"), String.valueOf(defaultPredicatePushdownEnabled))); + } } diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java index d7a6b58eca313..066dcc83c9a5c 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchMetadata.java @@ -29,7 +29,11 @@ import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.predicate.Domain; +import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.statistics.Estimate; +import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.DateType; import com.facebook.presto.spi.type.DoubleType; @@ -38,6 +42,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.airlift.tpch.LineItemColumn; import io.airlift.tpch.OrderColumn; import io.airlift.tpch.OrderGenerator; @@ -50,10 +56,13 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toSet; public class TpchMetadata implements ConnectorMetadata @@ -66,16 +75,22 @@ public class TpchMetadata public static final String ROW_NUMBER_COLUMN_NAME = "row_number"; + private static final Set ORDER_STATUS_VALUES = ImmutableSet.of("F", "O", "P"); + private static final Set ORDER_STATUS_NULLABLE_VALUES = ORDER_STATUS_VALUES.stream() + .map(value -> new NullableValue(getPrestoType(OrderColumn.ORDER_STATUS.getType()), Slices.utf8Slice(value))) + .collect(toSet()); + private final String connectorId; private final Set tableNames; + private final boolean predicatePushdownEnabled; private final ColumnNaming columnNaming; public TpchMetadata(String connectorId) { - this(connectorId, ColumnNaming.SIMPLIFIED); + this(connectorId, TpchConnectorFactory.DEFAULT_PREDICATE_PUSHDOWN_ENABLED, ColumnNaming.SIMPLIFIED); } - public TpchMetadata(String connectorId, ColumnNaming columnNaming) + public TpchMetadata(String connectorId, boolean predicatePushdownEnabled, ColumnNaming columnNaming) { ImmutableSet.Builder tableNames = ImmutableSet.builder(); for (TpchTable tpchTable : TpchTable.getTables()) { @@ -83,6 +98,7 @@ public TpchMetadata(String connectorId, ColumnNaming columnNaming) } this.tableNames = tableNames.build(); this.connectorId = connectorId; + this.predicatePushdownEnabled = predicatePushdownEnabled; this.columnNaming = columnNaming; } @@ -122,6 +138,8 @@ public List getTableLayouts( Optional> partitioningColumns = Optional.empty(); List> localProperties = ImmutableList.of(); + Optional> predicate = Optional.empty(); + TupleDomain unenforcedConstraint = constraint.getSummary(); Map columns = getColumnHandles(session, tableHandle); if (tableHandle.getTableName().equals(TpchTable.ORDERS.getTableName())) { ColumnHandle orderKeyColumn = columns.get(columnNaming.getName(OrderColumn.ORDER_KEY)); @@ -132,6 +150,15 @@ public List getTableLayouts( ImmutableList.of(orderKeyColumn))); partitioningColumns = Optional.of(ImmutableSet.of(orderKeyColumn)); localProperties = ImmutableList.of(new SortingProperty<>(orderKeyColumn, SortOrder.ASC_NULLS_FIRST)); + + if (predicatePushdownEnabled) { + predicate = Optional.of(toTupleDomain(ImmutableMap.of( + toColumnHandle(OrderColumn.ORDER_STATUS), + ORDER_STATUS_NULLABLE_VALUES.stream() + .filter(convertToPredicate(constraint.getSummary(), OrderColumn.ORDER_STATUS)) + .collect(toSet())))); + unenforcedConstraint = filterOutColumnFromPredicate(constraint.getSummary(), OrderColumn.ORDER_STATUS); + } } else if (tableHandle.getTableName().equals(TpchTable.LINE_ITEM.getTableName())) { ColumnHandle orderKeyColumn = columns.get(columnNaming.getName(LineItemColumn.ORDER_KEY)); @@ -147,15 +174,15 @@ else if (tableHandle.getTableName().equals(TpchTable.LINE_ITEM.getTableName())) } ConnectorTableLayout layout = new ConnectorTableLayout( - new TpchTableLayoutHandle(tableHandle), + new TpchTableLayoutHandle(tableHandle, predicate), Optional.empty(), - TupleDomain.all(), // TODO: return well-known properties (e.g., orderkey > 0, etc) + predicate.orElse(TupleDomain.all()), // TODO: return well-known properties (e.g., orderkey > 0, etc) nodePartition, partitioningColumns, Optional.empty(), localProperties); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return ImmutableList.of(new ConnectorTableLayoutResult(layout, unenforcedConstraint)); } @Override @@ -217,6 +244,63 @@ public Map> listTableColumns(ConnectorSess return tableColumns.build(); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint) + { + TpchTableHandle table = (TpchTableHandle) tableHandle; + return new TableStatistics(new Estimate(getRowCount(table, Optional.of(constraint.getSummary()))), ImmutableMap.of()); + } + + private long getRowCount(TpchTableHandle tpchTableHandle, Optional> predicate) + { + // todo expose row counts from airlift-tpch instead of hardcoding it here + // todo add stats for columns + String tableName = tpchTableHandle.getTableName(); + double scaleFactor = tpchTableHandle.getScaleFactor(); + switch (tableName.toLowerCase()) { + case "customer": + return (long) (150_000 * scaleFactor); + case "orders": + Set orderStatusValues = predicate.map(tupleDomain -> + ORDER_STATUS_NULLABLE_VALUES.stream() + .filter(convertToPredicate(tupleDomain, OrderColumn.ORDER_STATUS)) + .map(nullableValue -> ((Slice) nullableValue.getValue()).toStringUtf8()) + .collect(toSet())) + .orElse(ORDER_STATUS_VALUES); + + long totalRows = 0L; + if (orderStatusValues.contains("F")) { + totalRows = 729_413; + } + if (orderStatusValues.contains("O")) { + totalRows += 732_044; + } + if (orderStatusValues.contains("P")) { + totalRows += 38_543; + } + return (long) (totalRows * scaleFactor); + case "lineitem": + return (long) (6_000_000 * scaleFactor); + case "part": + return (long) (200_000 * scaleFactor); + case "partsupp": + return (long) (800_000 * scaleFactor); + case "supplier": + return (long) (10_000 * scaleFactor); + case "nation": + return 25; + case "region": + return 5; + default: + throw new IllegalArgumentException("unknown tpch table name '" + tableName + "'"); + } + } + + private TpchColumnHandle toColumnHandle(TpchColumn column) + { + return new TpchColumnHandle(columnNaming.getName(column), getPrestoType(column.getType())); + } + @Override public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) { @@ -243,6 +327,40 @@ public List listTables(ConnectorSession session, String schemaN return builder.build(); } + private TupleDomain toTupleDomain(Map> predicate) + { + return TupleDomain.withColumnDomains(predicate.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> { + Type type = entry.getKey().getType(); + return entry.getValue().stream() + .map(nullableValue -> Domain.singleValue(type, nullableValue.getValue())) + .reduce((Domain::union)) + .orElse(Domain.none(type)); + }))); + } + + private Predicate convertToPredicate(TupleDomain predicate, TpchColumn column) + { + return nullableValue -> predicate.contains(TupleDomain.fromFixedValues(ImmutableMap.of(toColumnHandle(column), nullableValue))); + } + + private TupleDomain filterOutColumnFromPredicate(TupleDomain predicate, TpchColumn column) + { + return filterColumns(predicate, tpchColumnHandle -> !tpchColumnHandle.equals(toColumnHandle(column))); + } + + private TupleDomain filterColumns(TupleDomain predicate, Predicate filterPredicate) + { + return predicate.transform(columnHandle -> { + TpchColumnHandle tpchColumnHandle = (TpchColumnHandle) columnHandle; + if (filterPredicate.test(tpchColumnHandle)) { + return tpchColumnHandle; + } + + return null; + }); + } + private List getSchemaNames(ConnectorSession session, String schemaNameOrNull) { List schemaNames; diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java index 8dafda4221ed2..21e352622da77 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSet.java @@ -13,10 +13,15 @@ */ package com.facebook.presto.tpch; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.RecordSet; +import com.facebook.presto.spi.predicate.NullableValue; +import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.tpch.TpchColumn; @@ -26,18 +31,24 @@ import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Predicate; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.tpch.TpchMetadata.getPrestoType; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Iterables.transform; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; public class TpchRecordSet implements RecordSet { public static TpchRecordSet createTpchRecordSet(TpchTable table, double scaleFactor) { - return createTpchRecordSet(table, table.getColumns(), scaleFactor, 1, 1); + return createTpchRecordSet(table, table.getColumns(), scaleFactor, 1, 1, Optional.empty()); } public static TpchRecordSet createTpchRecordSet( @@ -45,16 +56,19 @@ public static TpchRecordSet createTpchRecordSet( Iterable> columns, double scaleFactor, int part, - int partCount) + int partCount, + Optional> predicate) { - return new TpchRecordSet<>(table.createGenerator(scaleFactor, part, partCount), columns); + return new TpchRecordSet<>(table.createGenerator(scaleFactor, part, partCount), columns, predicate); } private final Iterable table; private final List> columns; private final List columnTypes; + private final List columnHandles; + private final Optional>> predicate; - public TpchRecordSet(Iterable table, Iterable> columns) + public TpchRecordSet(Iterable table, Iterable> columns, Optional> predicate) { requireNonNull(table, "readerSupplier is null"); @@ -62,6 +76,16 @@ public TpchRecordSet(Iterable table, Iterable> columns) this.columns = ImmutableList.copyOf(columns); this.columnTypes = ImmutableList.copyOf(transform(columns, column -> getPrestoType(column.getType()))); + + columnHandles = this.columns.stream() + .map(column -> new TpchColumnHandle(column.getColumnName(), getPrestoType(column.getType()))) + .collect(toList()); + this.predicate = predicate.map(TpchRecordSet::convertToPredicate); + } + + static Predicate> convertToPredicate(TupleDomain tupleDomain) + { + return bindings -> tupleDomain.contains(TupleDomain.fromFixedValues(bindings)); } @Override @@ -117,14 +141,16 @@ public Type getType(int field) @Override public boolean advanceNextPosition() { - if (closed || !rows.hasNext()) { - closed = true; - row = null; - return false; + while (!closed && rows.hasNext()) { + row = rows.next(); + if (rowMatchesPredicate()) { + return true; + } } - row = rows.next(); - return true; + closed = true; + row = null; + return false; } @Override @@ -180,6 +206,40 @@ public void close() closed = true; } + private boolean rowMatchesPredicate() + { + if (!predicate.isPresent()) { + return true; + } + return predicate.get().test(rowMap()); + } + + private Map rowMap() + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (int field = 0; field < columnHandles.size(); ++field) { + Type type = columnTypes.get(field); + builder.put(columnHandles.get(field), NullableValue.of(type, getPrestoObject(field, type))); + } + return builder.build(); + } + + private Object getPrestoObject(int field, Type type) + { + if (type.getJavaType() == long.class) { + return getLong(field); + } + else if (type.getJavaType() == double.class) { + return getDouble(field); + } + else if (type.getJavaType() == Slice.class) { + return getSlice(field); + } + else { + throw new PrestoException(NOT_SUPPORTED, format("Unsupported column type %s", type.getDisplayName())); + } + } + private TpchColumn getTpchColumn(int field) { return columns.get(field); diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSetProvider.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSetProvider.java index afaaa94bf4599..7f7d5054cae10 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSetProvider.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchRecordSetProvider.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.predicate.TupleDomain; import com.google.common.collect.ImmutableList; import io.airlift.tpch.TpchColumn; import io.airlift.tpch.TpchColumnType; @@ -26,6 +27,7 @@ import io.airlift.tpch.TpchTable; import java.util.List; +import java.util.Optional; import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; import static io.airlift.tpch.TpchColumnTypes.IDENTIFIER; @@ -42,7 +44,7 @@ public RecordSet getRecordSet(ConnectorTransactionHandle transaction, ConnectorS TpchTable tpchTable = TpchTable.getTable(tableName); - return getRecordSet(tpchTable, columns, tpchSplit.getTableHandle().getScaleFactor(), tpchSplit.getPartNumber(), tpchSplit.getTotalParts()); + return getRecordSet(tpchTable, columns, tpchSplit.getTableHandle().getScaleFactor(), tpchSplit.getPartNumber(), tpchSplit.getTotalParts(), tpchSplit.getPredicate()); } public RecordSet getRecordSet( @@ -50,7 +52,8 @@ public RecordSet getRecordSet( List columns, double scaleFactor, int partNumber, - int totalParts) + int totalParts, + Optional> predicate) { ImmutableList.Builder> builder = ImmutableList.builder(); for (ColumnHandle column : columns) { @@ -63,7 +66,7 @@ public RecordSet getRecordSet( } } - return createTpchRecordSet(table, builder.build(), scaleFactor, partNumber + 1, totalParts); + return createTpchRecordSet(table, builder.build(), scaleFactor, partNumber + 1, totalParts, predicate); } private static class RowNumberTpchColumn @@ -72,7 +75,7 @@ private static class RowNumberTpchColumn @Override public String getColumnName() { - throw new UnsupportedOperationException(); + return TpchMetadata.ROW_NUMBER_COLUMN_NAME; } @Override diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplit.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplit.java index c7578479da59a..81c2c0aa4affd 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplit.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplit.java @@ -13,14 +13,17 @@ */ package com.facebook.presto.tpch; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.predicate.TupleDomain; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; @@ -34,12 +37,14 @@ public class TpchSplit private final int totalParts; private final int partNumber; private final List addresses; + private final Optional> predicate; @JsonCreator public TpchSplit(@JsonProperty("tableHandle") TpchTableHandle tableHandle, @JsonProperty("partNumber") int partNumber, @JsonProperty("totalParts") int totalParts, - @JsonProperty("addresses") List addresses) + @JsonProperty("addresses") List addresses, + @JsonProperty("predicate") Optional> predicate) { checkState(partNumber >= 0, "partNumber must be >= 0"); checkState(totalParts >= 1, "totalParts must be >= 1"); @@ -49,6 +54,7 @@ public TpchSplit(@JsonProperty("tableHandle") TpchTableHandle tableHandle, this.partNumber = partNumber; this.totalParts = totalParts; this.addresses = ImmutableList.copyOf(requireNonNull(addresses, "addresses is null")); + this.predicate = requireNonNull(predicate, "predicate is null"); } @JsonProperty @@ -88,6 +94,12 @@ public List getAddresses() return addresses; } + @JsonProperty + public Optional> getPredicate() + { + return predicate; + } + @Override public boolean equals(Object obj) { @@ -100,7 +112,8 @@ public boolean equals(Object obj) TpchSplit other = (TpchSplit) obj; return Objects.equals(this.tableHandle, other.tableHandle) && Objects.equals(this.totalParts, other.totalParts) && - Objects.equals(this.partNumber, other.partNumber); + Objects.equals(this.partNumber, other.partNumber) && + Objects.equals(this.predicate, other.predicate); } @Override @@ -116,6 +129,7 @@ public String toString() .add("tableHandle", tableHandle) .add("partNumber", partNumber) .add("totalParts", totalParts) + .add("predicate", predicate) .toString(); } } diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplitManager.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplitManager.java index 59602daaff9b9..4159b87fce3da 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplitManager.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchSplitManager.java @@ -44,7 +44,8 @@ public TpchSplitManager(NodeManager nodeManager, int splitsPerNode) @Override public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableLayoutHandle layout) { - TpchTableHandle tableHandle = ((TpchTableLayoutHandle) layout).getTable(); + TpchTableLayoutHandle tableLayoutHandle = (TpchTableLayoutHandle) layout; + TpchTableHandle tableHandle = tableLayoutHandle.getTable(); Set nodes = nodeManager.getRequiredWorkerNodes(); @@ -55,7 +56,7 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, Co ImmutableList.Builder splits = ImmutableList.builder(); for (Node node : nodes) { for (int i = 0; i < splitsPerNode; i++) { - splits.add(new TpchSplit(tableHandle, partNumber, totalParts, ImmutableList.of(node.getHostAndPort()))); + splits.add(new TpchSplit(tableHandle, partNumber, totalParts, ImmutableList.of(node.getHostAndPort()), tableLayoutHandle.getPredicate())); partNumber++; } } diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchTableLayoutHandle.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchTableLayoutHandle.java index 2b03031efbfe3..cdc5061191163 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchTableLayoutHandle.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchTableLayoutHandle.java @@ -13,19 +13,25 @@ */ package com.facebook.presto.tpch; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.predicate.TupleDomain; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Optional; + public class TpchTableLayoutHandle implements ConnectorTableLayoutHandle { private final TpchTableHandle table; + private final Optional> predicate; @JsonCreator - public TpchTableLayoutHandle(@JsonProperty("table") TpchTableHandle table) + public TpchTableLayoutHandle(@JsonProperty("table") TpchTableHandle table, @JsonProperty("predicate") Optional> predicate) { this.table = table; + this.predicate = predicate; } @JsonProperty @@ -34,6 +40,12 @@ public TpchTableHandle getTable() return table; } + @JsonProperty + public Optional> getPredicate() + { + return predicate; + } + public String getConnectorId() { return table.getConnectorId(); diff --git a/presto-verifier/pom.xml b/presto-verifier/pom.xml index 6a96175c58165..7a29a67e7912c 100644 --- a/presto-verifier/pom.xml +++ b/presto-verifier/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.175-SNAPSHOT + 0.180-SNAPSHOT presto-verifier diff --git a/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java b/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java index 52a87bca96b2f..c7417197244f0 100644 --- a/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java +++ b/presto-verifier/src/main/java/com/facebook/presto/verifier/QueryRewriter.java @@ -169,7 +169,8 @@ private String createTemporaryTableName() return rewritePrefix.getSuffix() + UUID.randomUUID().toString().replace("-", ""); } - private List getColumnsForTable(Connection connection, String catalog, String schema, String table) throws SQLException + private List getColumnsForTable(Connection connection, String catalog, String schema, String table) + throws SQLException { ResultSet columns = connection.getMetaData().getColumns(catalog, escapeLikeExpression(connection, schema), escapeLikeExpression(connection, table), null); ImmutableList.Builder columnBuilder = new ImmutableList.Builder<>(); diff --git a/presto-verifier/src/main/java/com/facebook/presto/verifier/Validator.java b/presto-verifier/src/main/java/com/facebook/presto/verifier/Validator.java index 5aa5ccf648f1b..28db88891a31e 100644 --- a/presto-verifier/src/main/java/com/facebook/presto/verifier/Validator.java +++ b/presto-verifier/src/main/java/com/facebook/presto/verifier/Validator.java @@ -18,6 +18,7 @@ import com.facebook.presto.jdbc.QueryStats; import com.facebook.presto.spi.type.SqlVarbinary; import com.facebook.presto.verifier.Validator.ChangedRow.Changed; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.base.Stopwatch; import com.google.common.base.Throwables; @@ -35,7 +36,6 @@ import io.airlift.units.Duration; import java.math.BigDecimal; -import java.math.MathContext; import java.sql.Array; import java.sql.Connection; import java.sql.DriverManager; @@ -716,16 +716,33 @@ private static boolean isIntegral(Number x) return x instanceof Byte || x instanceof Short || x instanceof Integer || x instanceof Long; } - private static int precisionCompare(double a, double b, int precision) + //adapted from http://floating-point-gui.de/errors/comparison/ + private static boolean isClose(double a, double b, double epsilon) { + double absA = Math.abs(a); + double absB = Math.abs(b); + double diff = Math.abs(a - b); + if (!isFinite(a) || !isFinite(b)) { - return Double.compare(a, b); + return Double.compare(a, b) == 0; } - MathContext context = new MathContext(precision); - BigDecimal x = new BigDecimal(a).round(context); - BigDecimal y = new BigDecimal(b).round(context); - return x.compareTo(y); + // a or b is zero or both are extremely close to it + // relative error is less meaningful here + if (a == 0 || b == 0 || diff < Float.MIN_NORMAL) { + return diff < (epsilon * Float.MIN_NORMAL); + } + else { + // use relative error + return diff / Math.min((absA + absB), Float.MAX_VALUE) < epsilon; + } + } + + @VisibleForTesting + static int precisionCompare(double a, double b, int precision) + { + //we don't care whether a is smaller than b or not when they are not close since we will fail verification anyway + return isClose(a, b, Math.pow(10, -1 * (precision - 1))) ? 0 : -1; } public static class ChangedRow diff --git a/presto-verifier/src/test/java/com/facebook/presto/verifier/TestValidator.java b/presto-verifier/src/test/java/com/facebook/presto/verifier/TestValidator.java new file mode 100644 index 0000000000000..1cc8f1f204ef8 --- /dev/null +++ b/presto-verifier/src/test/java/com/facebook/presto/verifier/TestValidator.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.verifier; + +import org.testng.annotations.Test; + +import static com.facebook.presto.verifier.Validator.precisionCompare; +import static java.lang.Double.NaN; +import static org.testng.Assert.assertEquals; + +public class TestValidator +{ + @Test + public void testDoubleComparison() + throws Exception + { + assertEquals(precisionCompare(0.9045, 0.9045000000000001, 3), 0); + assertEquals(precisionCompare(0.9045, 0.9045000000000001, 2), 0); + assertEquals(precisionCompare(0.9041, 0.9042, 3), 0); + assertEquals(precisionCompare(0.9041, 0.9042, 4), 0); + assertEquals(precisionCompare(0.9042, 0.9041, 4), 0); + assertEquals(precisionCompare(-0.9042, -0.9041, 4), 0); + assertEquals(precisionCompare(-0.9042, -0.9041, 3), 0); + assertEquals(precisionCompare(0.899, 0.901, 3), 0); + assertEquals(precisionCompare(NaN, NaN, 4), Double.compare(NaN, NaN)); + } +} diff --git a/src/checkstyle/checks.xml b/src/checkstyle/checks.xml index 2f34e454a9cad..e1a74d83d02c3 100644 --- a/src/checkstyle/checks.xml +++ b/src/checkstyle/checks.xml @@ -47,6 +47,10 @@ + + + +