From 813c31d1a69a6585969e11477c8d39dfc2b3da29 Mon Sep 17 00:00:00 2001
From: qqibrow
Date: Wed, 6 Feb 2019 14:24:53 -0800
Subject: [PATCH] Push down dereference expressions to the Parquet reader

Add PushDownDereferenceExpression to push dereferences down to just above the TableScan
Add MergeNestedColumn to convert valid dereferences into ColumnHandles in the TableScan
Add NestedColumn to HiveColumnHandle
Change ParquetReader to use NestedColumn when reading files
---
 .../plugin/hive/HiveColumnHandle.java         |  23 +-
 .../prestosql/plugin/hive/HiveMetadata.java   |  39 +-
 .../plugin/hive/HivePageSourceProvider.java   |  18 +-
 .../io/prestosql/plugin/hive/HiveType.java    |  18 +
 .../io/prestosql/plugin/hive/HiveUtil.java    |   4 +-
 .../plugin/hive/orc/OrcPageSourceFactory.java |   2 +-
 .../hive/parquet/ParquetPageSource.java       | 131 ++++++-
 .../parquet/ParquetPageSourceFactory.java     |  19 +-
 .../plugin/hive/AbstractTestHiveClient.java   |  10 +-
 .../hive/AbstractTestHiveFileFormats.java     |   2 +-
 .../hive/TestBackgroundHiveSplitLoader.java   |   2 +-
 .../plugin/hive/TestHiveColumnHandle.java     |   4 +-
 .../plugin/hive/TestHiveMetadata.java         |   1 +
 .../plugin/hive/TestHivePageSink.java         |   2 +-
 .../prestosql/plugin/hive/TestHiveSplit.java  |   2 +-
 .../plugin/hive/TestIonSqlQueryBuilder.java   |  22 +-
 .../plugin/hive/TestJsonHiveHandles.java      |   2 +-
 .../hive/TestOrcPageSourceMemoryTracking.java |   3 +-
 .../plugin/hive/TestS3SelectRecordCursor.java |  10 +-
 .../plugin/hive/benchmark/FileFormat.java     |  14 +-
 .../parquet/AbstractTestParquetReader.java    |  29 ++
 .../plugin/hive/parquet/ParquetTester.java    |  57 ++-
 .../predicate/TestParquetPredicateUtils.java  |  10 +-
 .../TestMetastoreHiveStatisticsProvider.java  |   8 +-
 presto-main/etc/catalog/hive.properties       |   2 +
 .../java/io/prestosql/metadata/Metadata.java  |   8 +
 .../prestosql/metadata/MetadataManager.java   |   9 +
 .../prestosql/sql/planner/PlanOptimizers.java |  10 +
 .../iterative/rule/InlineProjections.java     |   1 +
 .../optimizations/MergeNestedColumn.java      | 316 +++++++++++++++
 .../PushDownDereferenceExpression.java        | 369 ++++++++++++++++++
 .../io/prestosql/testing/TestingMetadata.java |  19 +-
 .../metadata/AbstractMockMetadata.java        |   7 +
 .../sql/planner/TestDereferencePushDown.java  |  74 ++++
 .../sql/planner/TestMergeNestedColumns.java   | 264 +++++++++++++
 .../planner/assertions/ColumnReference.java   |   7 +-
 .../assertions/ExpressionVerifier.java        |  40 ++
 .../planner/assertions/PlanMatchPattern.java  |   5 +
 .../sql/planner/assertions/UnnestMatcher.java |  71 ++++
 .../iterative/rule/test/RuleTester.java       |   6 +
 .../prestosql/parquet/ParquetTypeUtils.java   |  46 ++-
 .../sql/tree/DereferenceExpression.java       |   5 +-
 .../java/io/prestosql/spi/NestedColumn.java   |  93 +++++
 .../spi/connector/ConnectorMetadata.java      |   7 +
 .../ClassLoaderSafeConnectorMetadata.java     |   9 +
 45 files changed, 1710 insertions(+), 90 deletions(-)
 create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java
 create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java
 create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java
 create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java
 create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java
 create mode 100644 presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java

diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java
b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java index a110a3d5f2ea..1d238cc056d4 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.type.TypeManager; @@ -60,6 +61,7 @@ public enum ColumnType private final int hiveColumnIndex; private final ColumnType columnType; private final Optional comment; + private final Optional nestedColumn; @JsonCreator public HiveColumnHandle( @@ -68,7 +70,8 @@ public HiveColumnHandle( @JsonProperty("typeSignature") TypeSignature typeSignature, @JsonProperty("hiveColumnIndex") int hiveColumnIndex, @JsonProperty("columnType") ColumnType columnType, - @JsonProperty("comment") Optional comment) + @JsonProperty("comment") Optional comment, + @JsonProperty("nestedColumn") Optional nestedColumn) { this.name = requireNonNull(name, "name is null"); checkArgument(hiveColumnIndex >= 0 || columnType == PARTITION_KEY || columnType == SYNTHESIZED, "hiveColumnIndex is negative"); @@ -77,6 +80,7 @@ public HiveColumnHandle( this.typeName = requireNonNull(typeSignature, "type is null"); this.columnType = requireNonNull(columnType, "columnType is null"); this.comment = requireNonNull(comment, "comment is null"); + this.nestedColumn = requireNonNull(nestedColumn, "nestedColumn is null"); } @JsonProperty @@ -97,6 +101,12 @@ public int getHiveColumnIndex() return hiveColumnIndex; } + @JsonProperty + public Optional getNestedColumn() + { + return nestedColumn; + } + public boolean isPartitionKey() { return columnType == PARTITION_KEY; @@ -133,7 +143,7 @@ public ColumnType getColumnType() @Override public int hashCode() { - return Objects.hash(name, hiveColumnIndex, hiveType, columnType, comment); + return Objects.hash(name, hiveColumnIndex, hiveType, columnType, comment, nestedColumn); } @Override @@ -150,7 +160,8 @@ public boolean equals(Object obj) Objects.equals(this.hiveColumnIndex, other.hiveColumnIndex) && Objects.equals(this.hiveType, other.hiveType) && Objects.equals(this.columnType, other.columnType) && - Objects.equals(this.comment, other.comment); + Objects.equals(this.comment, other.comment) && + Objects.equals(this.nestedColumn, other.nestedColumn); } @Override @@ -167,12 +178,12 @@ public static HiveColumnHandle updateRowIdHandle() // plan-time support for row-by-row delete so that planning doesn't fail. This is why we need // rowid handle. Note that in Hive connector, rowid handle is not implemented beyond plan-time. 
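
// For orientation: the Optional<NestedColumn> threaded through this handle
// describes a pushed-down dereference chain such as msg.foo.bar as a base
// column plus the remaining field path. Below is a minimal sketch of the shape
// implied by the accessors this patch uses (getBase, getRest, getName); the
// real class is added in presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java
// and may validate its input differently.

import com.google.common.collect.ImmutableList;

import java.util.List;

public final class NestedColumn
{
    private final List<String> names; // full path, e.g. ["msg", "foo", "bar"]

    public NestedColumn(List<String> names)
    {
        this.names = ImmutableList.copyOf(names);
    }

    public String getBase()
    {
        // top-level Hive column, e.g. "msg"
        return names.get(0);
    }

    public List<String> getRest()
    {
        // field path below the base, e.g. ["foo", "bar"]
        return names.subList(1, names.size());
    }

    public String getName()
    {
        // dotted name used for the synthesized handle, e.g. "msg.foo.bar"
        return String.join(".", names);
    }
}
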
- return new HiveColumnHandle(UPDATE_ROW_ID_COLUMN_NAME, HIVE_LONG, BIGINT.getTypeSignature(), -1, SYNTHESIZED, Optional.empty()); + return new HiveColumnHandle(UPDATE_ROW_ID_COLUMN_NAME, HIVE_LONG, BIGINT.getTypeSignature(), -1, SYNTHESIZED, Optional.empty(), Optional.empty()); } public static HiveColumnHandle pathColumnHandle() { - return new HiveColumnHandle(PATH_COLUMN_NAME, PATH_HIVE_TYPE, PATH_TYPE_SIGNATURE, PATH_COLUMN_INDEX, SYNTHESIZED, Optional.empty()); + return new HiveColumnHandle(PATH_COLUMN_NAME, PATH_HIVE_TYPE, PATH_TYPE_SIGNATURE, PATH_COLUMN_INDEX, SYNTHESIZED, Optional.empty(), Optional.empty()); } /** @@ -182,7 +193,7 @@ public static HiveColumnHandle pathColumnHandle() */ public static HiveColumnHandle bucketColumnHandle() { - return new HiveColumnHandle(BUCKET_COLUMN_NAME, BUCKET_HIVE_TYPE, BUCKET_TYPE_SIGNATURE, BUCKET_COLUMN_INDEX, SYNTHESIZED, Optional.empty()); + return new HiveColumnHandle(BUCKET_COLUMN_NAME, BUCKET_HIVE_TYPE, BUCKET_TYPE_SIGNATURE, BUCKET_COLUMN_INDEX, SYNTHESIZED, Optional.empty(), Optional.empty()); } public static boolean isPathColumnHandle(HiveColumnHandle column) diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java index 5283f58fda1b..1a98534173e6 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java @@ -15,6 +15,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.base.Suppliers; import com.google.common.base.Verify; @@ -44,6 +45,7 @@ import io.prestosql.plugin.hive.metastore.Table; import io.prestosql.plugin.hive.metastore.thrift.ThriftMetastoreUtil; import io.prestosql.plugin.hive.statistics.HiveStatisticsProvider; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.PrestoException; import io.prestosql.spi.StandardErrorCode; import io.prestosql.spi.block.Block; @@ -172,6 +174,7 @@ import static io.prestosql.plugin.hive.HiveUtil.decodeViewData; import static io.prestosql.plugin.hive.HiveUtil.encodeViewData; import static io.prestosql.plugin.hive.HiveUtil.getPartitionKeyColumnHandles; +import static io.prestosql.plugin.hive.HiveUtil.getRegularColumnHandles; import static io.prestosql.plugin.hive.HiveUtil.hiveColumnHandles; import static io.prestosql.plugin.hive.HiveUtil.schemaTableName; import static io.prestosql.plugin.hive.HiveUtil.toPartitionValues; @@ -616,6 +619,39 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((HiveColumnHandle) columnHandle).getColumnMetadata(typeManager); } + @Override + public Map getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection nestedColumns) + { + if (!HiveSessionProperties.isStatisticsEnabled(session)) { + return ImmutableMap.of(); + } + + SchemaTableName tableName = schemaTableName(tableHandle); + Optional table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); + + if (!table.isPresent()) { + throw new TableNotFoundException(tableName); + } + + // Only pushdown nested column for parquet table for now + if (!extractHiveStorageFormat(table.get()).equals(HiveStorageFormat.PARQUET)) { + return ImmutableMap.of(); + } + + List regularColumnHandles = getRegularColumnHandles(table.get()); + Map regularHiveColumnHandles = 
regularColumnHandles.stream().collect(Collectors.toMap(HiveColumnHandle::getName, identity())); + ImmutableMap.Builder columnHandles = ImmutableMap.builder(); + for (NestedColumn nestedColumn : nestedColumns) { + HiveColumnHandle hiveColumnHandle = regularHiveColumnHandles.get(nestedColumn.getBase()); + Optional childType = hiveColumnHandle.getHiveType().findChildType(nestedColumn); + Preconditions.checkArgument(childType.isPresent(), "%s doesn't exist in parent type %s", nestedColumn, hiveColumnHandle.getHiveType()); + if (hiveColumnHandle != null) { + columnHandles.put(nestedColumn, new HiveColumnHandle(nestedColumn.getName(), childType.get(), childType.get().getTypeSignature(), hiveColumnHandle.getHiveColumnIndex(), hiveColumnHandle.getColumnType(), hiveColumnHandle.getComment(), Optional.of(nestedColumn))); + } + } + return columnHandles.build(); + } + @Override public void createSchema(ConnectorSession session, String schemaName, Map properties) { @@ -2124,7 +2160,8 @@ else if (column.isHidden()) { column.getType().getTypeSignature(), ordinal, columnType, - Optional.ofNullable(column.getComment()))); + Optional.ofNullable(column.getComment()), + Optional.empty())); ordinal++; } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java index 505bb2f84ce9..acf101503a1f 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSet; import io.prestosql.plugin.hive.HdfsEnvironment.HdfsContext; import io.prestosql.plugin.hive.HiveSplit.BucketConversion; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ConnectorPageSource; import io.prestosql.spi.connector.ConnectorPageSourceProvider; @@ -319,8 +320,13 @@ public static List buildColumnMappings( for (HiveColumnHandle column : columns) { Optional coercionFrom = Optional.ofNullable(columnCoercions.get(column.getHiveColumnIndex())); if (column.getColumnType() == REGULAR) { - checkArgument(regularColumnIndices.add(column.getHiveColumnIndex()), "duplicate hiveColumnIndex in columns list"); - columnMappings.add(regular(column, regularIndex, coercionFrom)); + if (column.getNestedColumn().isPresent()) { + columnMappings.add(regular(column, regularIndex, getHiveType(coercionFrom, column.getNestedColumn().get()))); + } + else { + checkArgument(regularColumnIndices.add(column.getHiveColumnIndex()), "duplicate hiveColumnIndex in columns list"); + columnMappings.add(regular(column, regularIndex, coercionFrom)); + } regularIndex++; } else { @@ -344,6 +350,11 @@ public static List buildColumnMappings( return columnMappings.build(); } + private static Optional getHiveType(Optional baseType, NestedColumn nestedColumn) + { + return baseType.flatMap(type -> type.findChildType(nestedColumn)); + } + public static List extractRegularAndInterimColumnMappings(List columnMappings) { return columnMappings.stream() @@ -365,7 +376,8 @@ public static List toColumnHandles(List regular columnMapping.getCoercionFrom().get().getTypeSignature(), columnHandle.getHiveColumnIndex(), columnHandle.getColumnType(), - Optional.empty()); + Optional.empty(), + columnHandle.getNestedColumn()); }) .collect(toList()); } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java 
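
// getNestedColumnHandles above resolves each requested dereference against the
// table's regular columns. A condensed sketch of its core loop, with one
// reordering worth noting: the base handle should be checked for existence
// before its type is dereferenced (the hunk above applies the null check only
// after calling hiveColumnHandle.getHiveType()).

for (NestedColumn nestedColumn : nestedColumns) {
    HiveColumnHandle baseHandle = regularHiveColumnHandles.get(nestedColumn.getBase());
    if (baseHandle == null) {
        // no such top-level column; skip rather than risk an NPE
        continue;
    }
    Optional<HiveType> childType = baseHandle.getHiveType().findChildType(nestedColumn);
    checkArgument(childType.isPresent(), "%s doesn't exist in parent type %s", nestedColumn, baseHandle.getHiveType());
    columnHandles.put(nestedColumn, new HiveColumnHandle(
            nestedColumn.getName(),             // synthesized name, e.g. "msg.foo.bar"
            childType.get(),                    // leaf Hive type
            childType.get().getTypeSignature(),
            baseHandle.getHiveColumnIndex(),    // still addressed via the base column
            baseHandle.getColumnType(),
            baseHandle.getComment(),
            Optional.of(nestedColumn)));
}
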
b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java index 87d85731018e..75f085811352 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java @@ -15,7 +15,9 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.PrestoException; import io.prestosql.spi.type.NamedTypeSignature; import io.prestosql.spi.type.RowFieldName; @@ -175,6 +177,22 @@ public static boolean isSupportedType(TypeInfo typeInfo) return false; } + public Optional findChildType(NestedColumn nestedColumn) + { + TypeInfo typeInfo = getTypeInfo(); + for (String part : nestedColumn.getRest()) { + Preconditions.checkArgument(typeInfo instanceof StructTypeInfo, "typeinfo is not struct type", typeInfo); + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + try { + typeInfo = structTypeInfo.getStructFieldTypeInfo(part); + } + catch (RuntimeException e) { + return Optional.empty(); + } + } + return Optional.of(toHiveType(typeInfo)); + } + @JsonCreator public static HiveType valueOf(String hiveTypeName) { diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java index 75c4292af0b1..1818e74c2342 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java @@ -825,7 +825,7 @@ public static List getRegularColumnHandles(Table table) // ignore unsupported types rather than failing HiveType hiveType = field.getType(); if (hiveType.isSupportedType()) { - columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), hiveColumnIndex, REGULAR, field.getComment())); + columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), hiveColumnIndex, REGULAR, field.getComment(), Optional.empty())); } hiveColumnIndex++; } @@ -843,7 +843,7 @@ public static List getPartitionKeyColumnHandles(Table table) if (!hiveType.isSupportedType()) { throw new PrestoException(NOT_SUPPORTED, format("Unsupported Hive type %s found in partition keys of table %s.%s", hiveType, table.getDatabaseName(), table.getTableName())); } - columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), -1, PARTITION_KEY, field.getComment())); + columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), -1, PARTITION_KEY, field.getComment(), Optional.empty())); } return columns.build(); diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java index 5d4e60b02515..650054187550 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java @@ -264,7 +264,7 @@ private static List getPhysicalHiveColumnHandles(List columnNames; private final List types; private final List> fields; + private final List> nestedColumns; private final Block[] constantBlocks; private final int[] hiveColumnIndexes; @@ -86,6 +92,7 @@ public ParquetPageSource( int size = columns.size(); this.constantBlocks = new Block[size]; this.hiveColumnIndexes = new int[size]; + 
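// nestedColumns parallels the columns/types lists: entry i carries the
// dereference path for column i (empty for ordinary top-level columns), so
// getNextPage() can choose NestedColumnParquetBlockLoader for pushed-down fields.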
this.nestedColumns = new ArrayList<>(size); ImmutableList.Builder namesBuilder = ImmutableList.builder(); ImmutableList.Builder typesBuilder = ImmutableList.builder(); @@ -99,15 +106,22 @@ public ParquetPageSource( namesBuilder.add(name); typesBuilder.add(type); + nestedColumns.add(column.getNestedColumn()); hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex(); - if (getParquetType(column, fileSchema, useParquetColumnNames) == null) { + if (getColumnType(column, fileSchema, useParquetColumnNames) == null) { constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_VECTOR_LENGTH); fieldsBuilder.add(Optional.empty()); } else { - String columnName = useParquetColumnNames ? name : fileSchema.getFields().get(column.getHiveColumnIndex()).getName(); - fieldsBuilder.add(constructField(type, lookupColumnByName(messageColumnIO, columnName))); + if (column.getNestedColumn().isPresent()) { + NestedColumn nestedColumn = column.getNestedColumn().get(); + fieldsBuilder.add(constructField(getNestedStructType(nestedColumn, type), lookupColumnByName(messageColumnIO, nestedColumn.getBase()))); + } + else { + String columnName = useParquetColumnNames ? name : fileSchema.getFields().get(column.getHiveColumnIndex()).getName(); + fieldsBuilder.add(constructField(type, lookupColumnByName(messageColumnIO, columnName))); + } } } types = typesBuilder.build(); @@ -157,19 +171,17 @@ public Page getNextPage() blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize); } else { - Type type = types.get(fieldId); Optional field = fields.get(fieldId); - int fieldIndex; - if (useParquetColumnNames) { - fieldIndex = getFieldIndex(fileSchema, columnNames.get(fieldId)); - } - else { - fieldIndex = hiveColumnIndexes[fieldId]; - } - if (fieldIndex != -1 && field.isPresent()) { - blocks[fieldId] = new LazyBlock(batchSize, new ParquetBlockLoader(field.get())); + if (field.isPresent()) { + if (nestedColumns.get(fieldId).isPresent()) { + blocks[fieldId] = new LazyBlock(batchSize, new NestedColumnParquetBlockLoader(field.get(), types.get(fieldId))); + } + else { + blocks[fieldId] = new LazyBlock(batchSize, new ParquetBlockLoader(field.get())); + } } else { + Type type = types.get(fieldId); blocks[fieldId] = RunLengthEncodedBlock.create(type, null, batchSize); } } @@ -250,4 +262,93 @@ public final void load(LazyBlock lazyBlock) loaded = true; } } + + private final class NestedColumnParquetBlockLoader + implements LazyBlockLoader + { + private final int expectedBatchId = batchId; + private final Field field; + private final Type type; + private final int level; + private boolean loaded; + + // field is group field + public NestedColumnParquetBlockLoader(Field field, Type type) + { + this.field = requireNonNull(field, "field is null"); + this.type = requireNonNull(type, "type is null"); + this.level = getLevel(field.getType(), type); + } + + int getLevel(Type rootType, Type leafType) + { + int level = 0; + Type currentType = rootType; + while (!currentType.equals(leafType)) { + currentType = currentType.getTypeParameters().get(0); + ++level; + } + return level; + } + + @Override + public final void load(LazyBlock lazyBlock) + { + if (loaded) { + return; + } + + checkState(batchId == expectedBatchId); + + try { + Block block = parquetReader.readBlock(field); + + int size = block.getPositionCount(); + boolean[] isNulls = new boolean[size]; + + for (int currentLevel = 0; currentLevel < level; ++currentLevel) { + ColumnarRow rowBlock = ColumnarRow.toColumnarRow(block); + int index = 0; + for (int j = 0; j < 
size; ++j) { + if (!isNulls[j]) { + isNulls[j] = rowBlock.isNull(index); + ++index; + } + } + block = rowBlock.getField(0); + } + + BlockBuilder blockBuilder = type.createBlockBuilder(null, size); + int currentPosition = 0; + for (int i = 0; i < size; ++i) { + if (isNulls[i]) { + blockBuilder.appendNull(); + } + else { + Preconditions.checkArgument(currentPosition < block.getPositionCount(), "current position cannot exceed total position count"); + type.appendTo(block, currentPosition, blockBuilder); + currentPosition++; + } + } + lazyBlock.setBlock(blockBuilder.build()); + } + catch (ParquetCorruptionException e) { + throw new PrestoException(HIVE_BAD_DATA, e); + } + catch (IOException e) { + throw new PrestoException(HIVE_CURSOR_ERROR, e); + } + loaded = true; + } + } + + private Type getNestedStructType(NestedColumn nestedColumn, Type leafType) + { + Type type = leafType; + List names = nestedColumn.getRest(); + for (int i = names.size() - 1; i >= 0; --i) { + type = RowType.from(ImmutableList.of(RowType.field(names.get(i), type))); + } + return type; + } } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java index 27c854e92e80..eeefc4a98d8e 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -63,6 +63,7 @@ import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.prestosql.parquet.ParquetTypeUtils.getColumnIO; import static io.prestosql.parquet.ParquetTypeUtils.getDescriptors; +import static io.prestosql.parquet.ParquetTypeUtils.getNestedColumnType; import static io.prestosql.parquet.ParquetTypeUtils.getParquetTypeByName; import static io.prestosql.parquet.predicate.PredicateUtils.buildPredicate; import static io.prestosql.parquet.predicate.PredicateUtils.predicateMatches; @@ -77,7 +78,6 @@ import static io.prestosql.plugin.hive.parquet.HdfsParquetDataSource.buildHdfsParquetDataSource; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.PRIMITIVE; public class ParquetPageSourceFactory @@ -163,13 +163,14 @@ public static ParquetPageSource createParquetPageSource( MessageType fileSchema = fileMetaData.getSchema(); dataSource = buildHdfsParquetDataSource(inputStream, path, fileSize, stats); - List fields = columns.stream() + Optional optionalRequestedSchema = columns.stream() .filter(column -> column.getColumnType() == REGULAR) - .map(column -> getParquetType(column, fileSchema, useParquetColumnNames)) + .map(column -> getColumnType(column, fileSchema, useParquetColumnNames)) .filter(Objects::nonNull) - .collect(toList()); + .map(type -> new MessageType(fileSchema.getName(), type)) + .reduce(MessageType::union); - MessageType requestedSchema = new MessageType(fileSchema.getName(), fields); + MessageType requestedSchema = optionalRequestedSchema.orElseGet(() -> new MessageType(fileSchema.getName(), ImmutableList.of())); ImmutableList.Builder footerBlocks = ImmutableList.builder(); for (BlockMetaData block : parquetMetadata.getBlocks()) { @@ -266,4 +267,12 @@ public static org.apache.parquet.schema.Type getParquetType(HiveColumnHandle col } return null; } + + public static 
org.apache.parquet.schema.Type getColumnType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames) + { + if (useParquetColumnNames && column.getNestedColumn().isPresent()) { + return getNestedColumnType(messageType, column.getNestedColumn().get()); + } + return getParquetType(column, messageType, useParquetColumnNames); + } } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveClient.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveClient.java index 784113bb5865..d95773cae12d 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveClient.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveClient.java @@ -629,11 +629,11 @@ protected void setupHive(String connectorId, String databaseName, String timeZon Optional.empty(), Optional.empty()); - dsColumn = new HiveColumnHandle("ds", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty()); - fileFormatColumn = new HiveColumnHandle("file_format", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty()); - dummyColumn = new HiveColumnHandle("dummy", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty()); - intColumn = new HiveColumnHandle("t_int", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty()); - invalidColumnHandle = new HiveColumnHandle(INVALID_COLUMN, HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 0, REGULAR, Optional.empty()); + dsColumn = new HiveColumnHandle("ds", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty(), Optional.empty()); + fileFormatColumn = new HiveColumnHandle("file_format", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty(), Optional.empty()); + dummyColumn = new HiveColumnHandle("dummy", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty(), Optional.empty()); + intColumn = new HiveColumnHandle("t_int", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty(), Optional.empty()); + invalidColumnHandle = new HiveColumnHandle(INVALID_COLUMN, HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 0, REGULAR, Optional.empty(), Optional.empty()); List partitionColumns = ImmutableList.of(dsColumn, fileFormatColumn, dummyColumn); List partitions = ImmutableList.builder() diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java index c21693b2e2a2..3ccbb84aaf7e 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java @@ -483,7 +483,7 @@ protected List getColumnHandles(List testColumns) int columnIndex = testColumn.isPartitionKey() ? -1 : nextHiveColumnIndex++; HiveType hiveType = HiveType.valueOf(testColumn.getObjectInspector().getTypeName()); - columns.add(new HiveColumnHandle(testColumn.getName(), hiveType, hiveType.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty())); + columns.add(new HiveColumnHandle(testColumn.getName(), hiveType, hiveType.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? 
PARTITION_KEY : REGULAR, Optional.empty(), Optional.empty())); } return columns; } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java index d4c3a4cb8fdd..d0d835c9d6e2 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java @@ -94,7 +94,7 @@ public class TestBackgroundHiveSplitLoader private static final List PARTITION_COLUMNS = ImmutableList.of( new Column("partitionColumn", HIVE_INT, Optional.empty())); private static final List BUCKET_COLUMN_HANDLES = ImmutableList.of( - new HiveColumnHandle("col1", HIVE_INT, INTEGER.getTypeSignature(), 0, ColumnType.REGULAR, Optional.empty())); + new HiveColumnHandle("col1", HIVE_INT, INTEGER.getTypeSignature(), 0, ColumnType.REGULAR, Optional.empty(), Optional.empty())); private static final Optional BUCKET_PROPERTY = Optional.of( new HiveBucketProperty(ImmutableList.of("col1"), BUCKET_COUNT, ImmutableList.of())); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java index f7f097d422dc..867bef1187f9 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java @@ -38,14 +38,14 @@ public void testHiddenColumn() @Test public void testRegularColumn() { - HiveColumnHandle expectedPartitionColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty()); + HiveColumnHandle expectedPartitionColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty(), Optional.empty()); testRoundTrip(expectedPartitionColumn); } @Test public void testPartitionKeyColumn() { - HiveColumnHandle expectedRegularColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, REGULAR, Optional.empty()); + HiveColumnHandle expectedRegularColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, REGULAR, Optional.empty(), Optional.empty()); testRoundTrip(expectedRegularColumn); } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java index 8d7c515868fd..c3a9fb906224 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java @@ -34,6 +34,7 @@ public class TestHiveMetadata TypeSignature.parseTypeSignature("varchar"), 0, HiveColumnHandle.ColumnType.PARTITION_KEY, + Optional.empty(), Optional.empty()); @Test(timeOut = 5000) diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java index 9828602af8f5..5769f2be82ff 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java @@ -282,7 +282,7 @@ private static List getColumnHandles() for (int i = 0; i < columns.size(); i++) { LineItemColumn column = columns.get(i); HiveType hiveType = 
getHiveType(column.getType()); - handles.add(new HiveColumnHandle(column.getColumnName(), hiveType, hiveType.getTypeSignature(), i, REGULAR, Optional.empty())); + handles.add(new HiveColumnHandle(column.getColumnName(), hiveType, hiveType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } return handles.build(); } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java index fa01cdc3ffa7..891e23570104 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java @@ -61,7 +61,7 @@ public void testJsonRoundTrip() Optional.of(new HiveSplit.BucketConversion( 32, 16, - ImmutableList.of(new HiveColumnHandle("col", HIVE_LONG, BIGINT.getTypeSignature(), 5, ColumnType.REGULAR, Optional.of("comment"))))), + ImmutableList.of(new HiveColumnHandle("col", HIVE_LONG, BIGINT.getTypeSignature(), 5, ColumnType.REGULAR, Optional.of("comment"), Optional.empty())))), false); String json = codec.toJson(expected); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java index a38b0392024d..94987d922bc5 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java @@ -56,9 +56,9 @@ public void testBuildSQL() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("n_nationkey", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("n_name", HIVE_STRING, parseTypeSignature(VARCHAR), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("n_regionkey", HIVE_INT, parseTypeSignature(INTEGER), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("n_nationkey", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("n_name", HIVE_STRING, parseTypeSignature(VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("n_regionkey", HIVE_INT, parseTypeSignature(INTEGER), 2, REGULAR, Optional.empty(), Optional.empty())); assertEquals("SELECT s._1, s._2, s._3 FROM S3Object s", queryBuilder.buildSql(columns, TupleDomain.all())); @@ -81,9 +81,9 @@ public void testDecimalColumns() TypeManager typeManager = new TypeRegistry(); IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(typeManager); List columns = ImmutableList.of( - new HiveColumnHandle("quantity", HiveType.valueOf("decimal(20,0)"), parseTypeSignature(DECIMAL), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("extendedprice", HiveType.valueOf("decimal(20,2)"), parseTypeSignature(DECIMAL), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("discount", HiveType.valueOf("decimal(10,2)"), parseTypeSignature(DECIMAL), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("quantity", HiveType.valueOf("decimal(20,0)"), parseTypeSignature(DECIMAL), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("extendedprice", HiveType.valueOf("decimal(20,2)"), parseTypeSignature(DECIMAL), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("discount", HiveType.valueOf("decimal(10,2)"), parseTypeSignature(DECIMAL), 2, REGULAR, Optional.empty(), Optional.empty())); DecimalType decimalType = 
DecimalType.createDecimalType(10, 2); TupleDomain tupleDomain = withColumnDomains( ImmutableMap.of( @@ -101,8 +101,8 @@ public void testDateColumn() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("t1", HIVE_TIMESTAMP, parseTypeSignature(TIMESTAMP), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("t2", HIVE_DATE, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty())); + new HiveColumnHandle("t1", HIVE_TIMESTAMP, parseTypeSignature(TIMESTAMP), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("t2", HIVE_DATE, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty(), Optional.empty())); TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of( columns.get(1), Domain.create(SortedRangeSet.copyOf(DATE, ImmutableList.of(Range.equal(DATE, (long) DateTimeUtils.parseDate("2001-08-22")))), false))); @@ -114,9 +114,9 @@ public void testNotPushDoublePredicates() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("extendedprice", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("discount", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("extendedprice", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("discount", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 2, REGULAR, Optional.empty(), Optional.empty())); TupleDomain tupleDomain = withColumnDomains( ImmutableMap.of( columns.get(0), Domain.create(ofRanges(Range.lessThan(BIGINT, 50L)), false), diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java index bdb3ee55513b..50963432674b 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java @@ -77,7 +77,7 @@ public void testTableHandleDeserialize() public void testColumnHandleSerialize() throws Exception { - HiveColumnHandle columnHandle = new HiveColumnHandle("column", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), -1, PARTITION_KEY, Optional.of("comment")); + HiveColumnHandle columnHandle = new HiveColumnHandle("column", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), -1, PARTITION_KEY, Optional.of("comment"), Optional.empty()); assertTrue(objectMapper.canSerialize(HiveColumnHandle.class)); String json = objectMapper.writeValueAsString(columnHandle); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java index 310d868627c6..25b5f9c1a13e 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java @@ -444,8 +444,7 @@ public TestPreparer(String tempFilePath, List testColumns, int numRo ObjectInspector inspector = 
testColumn.getObjectInspector(); HiveType hiveType = HiveType.valueOf(inspector.getTypeName()); Type type = hiveType.getType(TYPE_MANAGER); - - columnsBuilder.add(new HiveColumnHandle(testColumn.getName(), hiveType, type.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty())); + columnsBuilder.add(new HiveColumnHandle(testColumn.getName(), hiveType, type.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty(), Optional.empty())); typesBuilder.add(type); } columns = columnsBuilder.build(); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java index da6aeaf4306f..57d5f539dbfd 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java @@ -48,12 +48,12 @@ public class TestS3SelectRecordCursor { private static final String LAZY_SERDE_CLASS_NAME = LazySimpleSerDe.class.getName(); - private static final HiveColumnHandle ARTICLE_COLUMN = new HiveColumnHandle("article", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle AUTHOR_COLUMN = new HiveColumnHandle("author", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle DATE_ARTICLE_COLUMN = new HiveColumnHandle("date_pub", HIVE_INT, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle QUANTITY_COLUMN = new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), 1, REGULAR, Optional.empty()); + private static final HiveColumnHandle ARTICLE_COLUMN = new HiveColumnHandle("article", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle AUTHOR_COLUMN = new HiveColumnHandle("author", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle DATE_ARTICLE_COLUMN = new HiveColumnHandle("date_pub", HIVE_INT, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle QUANTITY_COLUMN = new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), 1, REGULAR, Optional.empty(), Optional.empty()); private static final HiveColumnHandle[] DEFAULT_TEST_COLUMNS = {ARTICLE_COLUMN, AUTHOR_COLUMN, DATE_ARTICLE_COLUMN, QUANTITY_COLUMN}; - private static final HiveColumnHandle MOCK_HIVE_COLUMN_HANDLE = new HiveColumnHandle("mockName", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle MOCK_HIVE_COLUMN_HANDLE = new HiveColumnHandle("mockName", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty(), Optional.empty()); private static final TypeManager MOCK_TYPE_MANAGER = new TestingTypeManager(); private static final Path MOCK_PATH = new Path("mockPath"); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java index 44e463056eca..53ad25fef5c1 100644 --- 
a/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java @@ -13,6 +13,7 @@ */ package io.prestosql.plugin.hive.benchmark; +import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import io.airlift.slice.OutputStreamSliceOutput; import io.prestosql.orc.OrcWriter; @@ -43,6 +44,7 @@ import io.prestosql.rcfile.RcFileWriter; import io.prestosql.rcfile.binary.BinaryRcFileEncoding; import io.prestosql.rcfile.text.TextRcFileEncoding; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.Page; import io.prestosql.spi.connector.ConnectorPageSource; import io.prestosql.spi.connector.ConnectorSession; @@ -354,7 +356,7 @@ private static ConnectorPageSource createPageSource( for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); Type columnType = columnTypes.get(i); - columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty())); + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } RecordCursor recordCursor = cursorProvider @@ -388,7 +390,15 @@ private static ConnectorPageSource createPageSource( for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); Type columnType = columnTypes.get(i); - columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty())); + + if (columnName.contains(".")) { + Splitter splitter = Splitter.on('.').trimResults().omitEmptyStrings(); + List names = splitter.splitToList(columnName); + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.of(new NestedColumn(names)))); + } + else { + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); + } } return pageSourceFactory diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java index 3e40d0bc00e0..6a899c699903 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java @@ -18,9 +18,11 @@ import com.google.common.collect.ContiguousSet; import com.google.common.collect.DiscreteDomain; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import com.google.common.collect.Range; import com.google.common.primitives.Shorts; import io.airlift.units.DataSize; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.RowType; import io.prestosql.spi.type.SqlDate; @@ -533,6 +535,33 @@ public void testNestedStructs() tester.testRoundTrip(objectInspector, values, values, type); } + @Test + public void testSingleNestedColumn() + throws Exception + { + int nestingLevel = ThreadLocalRandom.current().nextInt(1, 15); + Optional> structFieldNames = Optional.of(singletonList("structField")); + + Iterable values = createTestStructs(limit(cycle(asList(1, null, 3, null, 5, 
null, 7, null, null, null, 11, null, 13)), 3_210)); + Iterable readValues = Lists.newArrayList(values); + + ObjectInspector objectInspector = getStandardStructObjectInspector(structFieldNames.get(), singletonList(javaIntObjectInspector)); + Type type = RowType.from(singletonList(field("structField", INTEGER))); + + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add("root"); + + for (int i = 0; i < nestingLevel; i++) { + int n = ThreadLocalRandom.current().nextInt(2, 5); + values = insertNullEvery(n, createTestStructs(values)); + readValues = insertNullEvery(n, readValues); + objectInspector = getStandardStructObjectInspector(structFieldNames.get(), singletonList(objectInspector)); + builder.add("structField"); + } + tester.testNestedColumnRoundTrip(singletonList(objectInspector), new Iterable[] {values}, + new Iterable[] {readValues}, ImmutableList.of(new NestedColumn(builder.build())), singletonList(type), Optional.empty(), false); + } + @Test public void testComplexNestedStructs() throws Exception diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java index 83ae6ec7a9eb..3d0e6e843997 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java @@ -32,6 +32,7 @@ import io.prestosql.plugin.hive.parquet.write.SingleLevelArrayMapKeyValuesSchemaConverter; import io.prestosql.plugin.hive.parquet.write.SingleLevelArraySchemaConverter; import io.prestosql.plugin.hive.parquet.write.TestMapredParquetOutputFormat; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ConnectorPageSource; @@ -80,6 +81,7 @@ import java.util.Optional; import java.util.Properties; import java.util.Set; +import java.util.stream.Collectors; import static com.google.common.base.Functions.constant; import static com.google.common.collect.Iterables.transform; @@ -190,6 +192,19 @@ public void testRoundTrip(ObjectInspector objectInspector, Iterable writeValu } } + public void testNestedColumnRoundTrip(List objectInspectors, Iterable[] writeValues, Iterable[] readValues, List nestedColumns, List columnTypes, Optional parquetSchema, boolean singleLevelArray) + throws Exception + { + List columnNames = nestedColumns.stream().map(NestedColumn::getName).collect(Collectors.toList()); + List rootNames = nestedColumns.stream().map(NestedColumn::getBase).collect(Collectors.toList()); + + // just the values + testRoundTripType(objectInspectors, writeValues, readValues, columnNames, columnTypes, Optional.of(rootNames), parquetSchema, singleLevelArray); + + // all nulls + assertRoundTrip(objectInspectors, transformToNulls(writeValues), transformToNulls(readValues), columnNames, columnTypes, Optional.of(rootNames), parquetSchema, singleLevelArray); + } + public void testRoundTrip(ObjectInspector objectInspector, Iterable writeValues, Iterable readValues, Type type, Optional parquetSchema) throws Exception { @@ -247,18 +262,32 @@ private void testRoundTripType( Optional parquetSchema, boolean singleLevelArray) throws Exception + { + testRoundTripType(objectInspectors, writeValues, readValues, columnNames, columnTypes, Optional.empty(), parquetSchema, singleLevelArray); + } + + private void testRoundTripType( + List objectInspectors, + Iterable[] writeValues, + Iterable[] readValues, + List columnNames, + List 
columnTypes, + Optional> rootColumns, + Optional parquetSchema, + boolean singleLevelArray) + throws Exception { // forward order - assertRoundTrip(objectInspectors, writeValues, readValues, columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, writeValues, readValues, columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); // reverse order - assertRoundTrip(objectInspectors, reverse(writeValues), reverse(readValues), columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, reverse(writeValues), reverse(readValues), columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); // forward order with nulls - assertRoundTrip(objectInspectors, insertNullEvery(5, writeValues), insertNullEvery(5, readValues), columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, insertNullEvery(5, writeValues), insertNullEvery(5, readValues), columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); // reverse order with nulls - assertRoundTrip(objectInspectors, insertNullEvery(5, reverse(writeValues)), insertNullEvery(5, reverse(readValues)), columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, insertNullEvery(5, reverse(writeValues)), insertNullEvery(5, reverse(readValues)), columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); } void assertRoundTrip( @@ -282,6 +311,19 @@ void assertRoundTrip( Optional parquetSchema, boolean singleLevelArray) throws Exception + { + assertRoundTrip(objectInspectors, writeValues, readValues, columnNames, columnTypes, Optional.empty(), parquetSchema, singleLevelArray); + } + + void assertRoundTrip(List objectInspectors, + Iterable[] writeValues, + Iterable[] readValues, + List columnNames, + List columnTypes, + Optional> rootColumns, + Optional parquetSchema, + boolean singleLevelArray) + throws Exception { for (WriterVersion version : versions) { for (CompressionCodecName compressionCodecName : compressions) { @@ -291,12 +333,13 @@ void assertRoundTrip( jobConf.setEnum(COMPRESSION, compressionCodecName); jobConf.setBoolean(ENABLE_DICTIONARY, true); jobConf.setEnum(WRITER_VERSION, version); + List writeColumns = rootColumns.isPresent() ? 
rootColumns.get() : columnNames; writeParquetColumn( jobConf, tempFile.getFile(), compressionCodecName, - createTableProperties(columnNames, objectInspectors), - getStandardStructObjectInspector(columnNames, objectInspectors), + createTableProperties(writeColumns, objectInspectors), + getStandardStructObjectInspector(writeColumns, objectInspectors), getIterators(writeValues), parquetSchema, singleLevelArray); @@ -601,7 +644,7 @@ private Iterable[] transformToNulls(Iterable[] values) private static Iterable[] reverse(Iterable[] iterables) { return stream(iterables) - .map(ImmutableList::copyOf) + .map(Lists::newArrayList) .map(Lists::reverse) .toArray(size -> new Iterable[size]); } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java index 4c866befc501..ee89b6502f92 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -89,7 +89,7 @@ public void testDictionaryEncodingCasesV1() @Test public void testParquetTupleDomainPrimitiveArray() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_array", HiveType.valueOf("array"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_array", HiveType.valueOf("array"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty(), Optional.empty()); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(INTEGER)))); MessageType fileSchema = new MessageType("hive_schema", @@ -104,7 +104,7 @@ public void testParquetTupleDomainPrimitiveArray() @Test public void testParquetTupleDomainStructArray() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_array_struct", HiveType.valueOf("array>"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_array_struct", HiveType.valueOf("array>"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty(), Optional.empty()); RowType.Field rowField = new RowType.Field(Optional.of("a"), INTEGER); RowType rowType = RowType.from(ImmutableList.of(rowField)); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(rowType)))); @@ -122,7 +122,7 @@ public void testParquetTupleDomainStructArray() @Test public void testParquetTupleDomainPrimitive() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_primitive", HiveType.valueOf("bigint"), parseTypeSignature(StandardTypes.BIGINT), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_primitive", HiveType.valueOf("bigint"), parseTypeSignature(StandardTypes.BIGINT), 0, REGULAR, Optional.empty(), Optional.empty()); Domain singleValueDomain = Domain.singleValue(BIGINT, 123L); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, singleValueDomain)); @@ -143,7 +143,7 @@ public void testParquetTupleDomainPrimitive() @Test public void testParquetTupleDomainStruct() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_struct", HiveType.valueOf("struct"), parseTypeSignature(StandardTypes.ROW), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_struct", HiveType.valueOf("struct"), 
parseTypeSignature(StandardTypes.ROW), 0, REGULAR, Optional.empty(), Optional.empty()); RowType.Field rowField = new RowType.Field(Optional.of("my_struct"), INTEGER); RowType rowType = RowType.from(ImmutableList.of(rowField)); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(rowType))); @@ -160,7 +160,7 @@ public void testParquetTupleDomainStruct() @Test public void testParquetTupleDomainMap() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_map", HiveType.valueOf("map"), parseTypeSignature(StandardTypes.MAP), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_map", HiveType.valueOf("map"), parseTypeSignature(StandardTypes.MAP), 0, REGULAR, Optional.empty(), Optional.empty()); MapType mapType = new MapType( INTEGER, diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java index 2145c0e77709..d26309d7ad00 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java @@ -94,8 +94,8 @@ public class TestMetastoreHiveStatisticsProvider private static final String COLUMN = "column"; private static final DecimalType DECIMAL = createDecimalType(5, 3); - private static final HiveColumnHandle PARTITION_COLUMN_1 = new HiveColumnHandle("p1", HIVE_STRING, VARCHAR.getTypeSignature(), 0, PARTITION_KEY, Optional.empty()); - private static final HiveColumnHandle PARTITION_COLUMN_2 = new HiveColumnHandle("p2", HIVE_LONG, BIGINT.getTypeSignature(), 1, PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle PARTITION_COLUMN_1 = new HiveColumnHandle("p1", HIVE_STRING, VARCHAR.getTypeSignature(), 0, PARTITION_KEY, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle PARTITION_COLUMN_2 = new HiveColumnHandle("p2", HIVE_LONG, BIGINT.getTypeSignature(), 1, PARTITION_KEY, Optional.empty(), Optional.empty()); @Test public void testGetPartitionsSample() @@ -611,7 +611,7 @@ public void testGetTableStatistics() .build(); MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((table, hivePartitions) -> ImmutableMap.of(partitionName, statistics)); TestingConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties()); - HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty(), Optional.empty()); TableStatistics expected = TableStatistics.builder() .setRowCount(Estimate.of(1000)) .setColumnStatistics( @@ -661,7 +661,7 @@ public void testGetTableStatisticsUnpartitioned() .build(); MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((table, hivePartitions) -> ImmutableMap.of(UNPARTITIONED_ID, statistics)); TestingConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties()); - HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, 
BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty(), Optional.empty()); TableStatistics expected = TableStatistics.builder() .setRowCount(Estimate.of(1000)) .setColumnStatistics( diff --git a/presto-main/etc/catalog/hive.properties b/presto-main/etc/catalog/hive.properties index 9060aff2a6a7..462c16ee537f 100644 --- a/presto-main/etc/catalog/hive.properties +++ b/presto-main/etc/catalog/hive.properties @@ -7,3 +7,5 @@ connector.name=hive-hadoop2 hive.metastore.uri=thrift://localhost:9083 + +hive.parquet.use-column-names=true diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index d8f505f83d1f..c212b7a13522 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.prestosql.Session; import io.prestosql.connector.ConnectorId; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.BlockEncodingSerde; import io.prestosql.spi.connector.CatalogSchemaName; @@ -117,6 +118,13 @@ public interface Metadata */ Map<String, ColumnHandle> getColumnHandles(Session session, TableHandle tableHandle);
+ /** + * Gets handles for the nested columns described by the given dereference paths, or an empty map + * if the connector cannot push the dereferences down into the table scan. + * + * @throws RuntimeException if table handle is no longer valid + */ + Map<NestedColumn, ColumnHandle> getNestedColumnHandles(Session session, TableHandle tableHandle, Collection<NestedColumn> dereferences); + /** * Gets the metadata for the specified table column. * diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 411ae09220f9..cb4a4102ed4e 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -27,6 +27,7 @@ import io.prestosql.Session; import io.prestosql.block.BlockEncodingManager; import io.prestosql.connector.ConnectorId; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.PrestoException; import io.prestosql.spi.QueryId; import io.prestosql.spi.block.BlockEncodingSerde; @@ -478,6 +479,14 @@ public Map getColumnHandles(Session session, TableHandle t return map.build(); } + + @Override + public Map<NestedColumn, ColumnHandle> getNestedColumnHandles(Session session, TableHandle tableHandle, Collection<NestedColumn> dereferences) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + return metadata.getNestedColumnHandles(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), dereferences); + } + @Override public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 38ec9e8bda3c..b29f005597d0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -113,12 +113,14 @@ import io.prestosql.sql.planner.optimizations.ImplementIntersectAndExceptAsUnion; import io.prestosql.sql.planner.optimizations.IndexJoinOptimizer; import
io.prestosql.sql.planner.optimizations.LimitPushDown; +import io.prestosql.sql.planner.optimizations.MergeNestedColumn; import io.prestosql.sql.planner.optimizations.MetadataDeleteOptimizer; import io.prestosql.sql.planner.optimizations.MetadataQueryOptimizer; import io.prestosql.sql.planner.optimizations.OptimizeMixedDistinctAggregations; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.optimizations.PredicatePushDown; import io.prestosql.sql.planner.optimizations.PruneUnreferencedOutputs; +import io.prestosql.sql.planner.optimizations.PushDownDereferenceExpression; import io.prestosql.sql.planner.optimizations.SetFlatteningOptimizer; import io.prestosql.sql.planner.optimizations.StatsRecordingPlanOptimizer; import io.prestosql.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; @@ -349,6 +351,13 @@ public PlanOptimizers( new RemoveRedundantIdentityProjections(), new TransformCorrelatedSingleRowSubqueryToProject())), new CheckSubqueryNodesAreRewritten(), + + // pushdown dereference + new PushDownDereferenceExpression(metadata, sqlParser), + new PruneUnreferencedOutputs(), + new MergeNestedColumn(metadata, sqlParser), + new IterativeOptimizer(ruleStats, statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PruneTableScanColumns())), + predicatePushDown, new IterativeOptimizer( ruleStats, @@ -398,6 +407,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new EliminateCrossJoins())), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again + predicatePushDown, simplifyOptimizer, // Should be always run after PredicatePushDown new IterativeOptimizer( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index 5ae6bad95514..c5f85fc30380 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -156,6 +156,7 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever + //.filter(entry -> !(child.getAssignments().get(entry.getKey()) instanceof DereferenceExpression)) // skip dereference expression .map(Map.Entry::getKey) .collect(toSet()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java new file mode 100644 index 000000000000..7c2e1462160a --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java @@ -0,0 +1,316 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.optimizations; + +import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.TableHandle; +import io.prestosql.spi.NestedColumn; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.analyzer.ExpressionAnalyzer; +import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SimplePlanRewriter; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.ExpressionRewriter; +import io.prestosql.sql.tree.ExpressionTreeRewriter; +import io.prestosql.sql.tree.Identifier; +import io.prestosql.sql.tree.NodeRef; +import io.prestosql.sql.tree.SubscriptExpression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; +
+/** + * Merges dereference expressions that appear in a Project directly above a TableScan into nested + * column handles provided by the connector, so only the referenced sub-fields are read. + */ +public class MergeNestedColumn + implements PlanOptimizer +{ + private final Metadata metadata; + private final SqlParser sqlParser; + + public MergeNestedColumn(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + @Override + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + return SimplePlanRewriter.rewriteWith(new Optimizer(session, symbolAllocator, idAllocator, metadata, sqlParser, warningCollector), plan); + }
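+ // Returns true if some other candidate in allDereferences is a proper prefix of the given
+ // expression. Illustrative example (names are hypothetical): with candidates {msg.foo, msg.foo.bar},
+ // prefixExist(msg.foo.bar, candidates) is true, so only the shorter chain msg.foo is kept as a
+ // pushdown candidate and msg.foo.bar is later rebuilt on top of it.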
+ public static boolean prefixExist(Expression expression, final Set<Expression> allDereferences) + { + int[] referenceCount = {0}; + new DefaultExpressionTraversalVisitor<Void, int[]>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, int[] referenceCount) + { + if (allDereferences.contains(node)) { + referenceCount[0] += 1; + } + process(node.getBase(), referenceCount); + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, int[] context) + { + if (allDereferences.contains(node)) { + referenceCount[0] += 1; + } + return null; + } + }.process(expression, referenceCount); + return referenceCount[0] > 1; + } + + private static class Optimizer + extends SimplePlanRewriter<Void> + { + private final Session session; + private final SymbolAllocator symbolAllocator; + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final SqlParser sqlParser; + private final WarningCollector warningCollector; + + public Optimizer(Session session, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, SqlParser sqlParser, WarningCollector warningCollector) + { + this.session = requireNonNull(session, "session is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context) + { + if (node.getSource() instanceof TableScanNode) { + TableScanNode tableScanNode = (TableScanNode) node.getSource(); + return mergeProjectWithTableScan(node, tableScanNode, context); + } + return context.defaultRewrite(node); + } + + private Type extractType(Expression expression) + { + Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer.getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), warningCollector); + return expressionTypes.get(NodeRef.of(expression)); + } + + public PlanNode mergeProjectWithTableScan(ProjectNode node, TableScanNode tableScanNode, RewriteContext<Void> context) + { + Set<Expression> allExpressions = node.getAssignments().getExpressions().stream().map(MergeNestedColumn::validDereferenceExpression).filter(Objects::nonNull).collect(toImmutableSet()); + Set<Expression> dereferences = allExpressions.stream() + .filter(expression -> !prefixExist(expression, allExpressions)) + .filter(expression -> expression instanceof DereferenceExpression) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return context.defaultRewrite(node); + } + + NestedColumnTranslator nestedColumnTranslator = new NestedColumnTranslator(tableScanNode.getAssignments(), tableScanNode.getTable()); + Map<Expression, NestedColumn> nestedColumns = dereferences.stream().collect(Collectors.toMap(Function.identity(), nestedColumnTranslator::toNestedColumn)); + + Map<NestedColumn, ColumnHandle> nestedColumnHandles = + metadata.getNestedColumnHandles(session, tableScanNode.getTable(), nestedColumns.values()) + .entrySet().stream() + .filter(entry -> !nestedColumnTranslator.columnHandleExists(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (nestedColumnHandles.isEmpty()) { + return context.defaultRewrite(node); + } + + ImmutableMap.Builder<Symbol, ColumnHandle> columnHandleBuilder = ImmutableMap.builder(); + columnHandleBuilder.putAll(tableScanNode.getAssignments()); + + // used to replace the original dereference expressions with references to the new symbols + ImmutableMap.Builder<Expression, Symbol> symbolExpressionBuilder = ImmutableMap.builder(); + for (Map.Entry<NestedColumn, ColumnHandle> entry : nestedColumnHandles.entrySet()) { + NestedColumn nestedColumn = entry.getKey(); + Expression expression = nestedColumnTranslator.toExpression(nestedColumn); + Symbol symbol = symbolAllocator.newSymbol(nestedColumn.getName(), extractType(expression)); + symbolExpressionBuilder.put(expression, symbol); + columnHandleBuilder.put(symbol, entry.getValue()); + } + ImmutableMap<Symbol, ColumnHandle> nestedColumnsMap = columnHandleBuilder.build(); +
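+ // Sketch of the rewrite at this point (symbol names are illustrative): a plan like
+ // Project[x := msg.foo.x] -> TableScan[msg]
+ // becomes
+ // Project[x := msg_foo_x] -> TableScan[msg, msg_foo_x := nested column msg.foo.x]
+ // so the connector only needs to read the referenced sub-field of the nested column.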
+ TableScanNode newTableScan = new TableScanNode(idAllocator.getNextId(), tableScanNode.getTable(), ImmutableList.copyOf(nestedColumnsMap.keySet()), nestedColumnsMap, tableScanNode.getLayout(), tableScanNode.getCurrentConstraint(), tableScanNode.getEnforcedConstraint()); + + Rewriter rewriter = new Rewriter(symbolExpressionBuilder.build()); + Map<Symbol, Expression> assignments = node.getAssignments().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> ExpressionTreeRewriter.rewriteWith(rewriter, entry.getValue()))); + return new ProjectNode(idAllocator.getNextId(), newTableScan, Assignments.copyOf(assignments)); + } + + private class NestedColumnTranslator + { + private final Map<Symbol, String> symbolToColumnName; + private final Map<String, Symbol> columnNameToSymbol; + + NestedColumnTranslator(Map<Symbol, ColumnHandle> columnHandleMap, TableHandle tableHandle) + { + BiMap<Symbol, String> symbolToColumnName = HashBiMap.create(columnHandleMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> metadata.getColumnMetadata(session, tableHandle, entry.getValue()).getName()))); + this.symbolToColumnName = symbolToColumnName; + this.columnNameToSymbol = symbolToColumnName.inverse(); + } + + boolean columnHandleExists(NestedColumn nestedColumn) + { + return columnNameToSymbol.containsKey(nestedColumn.getName()); + } + + NestedColumn toNestedColumn(Expression expression) + { + ImmutableList.Builder<String> builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor<Void, Void>() + { + @Override + protected Void visitSubscriptExpression(SubscriptExpression node, Void context) + { + return null; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + process(node.getBase(), context); + builder.add(node.getField().getValue()); + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, Void context) + { + Symbol baseName = Symbol.from(node); + Preconditions.checkArgument(symbolToColumnName.containsKey(baseName), "base [%s] doesn't exist in assignments [%s]", baseName, symbolToColumnName); + builder.add(symbolToColumnName.get(baseName)); + return null; + } + }.process(expression, null); + List<String> names = builder.build(); + Preconditions.checkArgument(names.size() > 1, "names size must be greater than 1: %s", names); + return new NestedColumn(names); + } + + Expression toExpression(NestedColumn nestedColumn) + { + Expression result = null; + for (String part : nestedColumn.getNames()) { + if (result == null) { + Preconditions.checkArgument(columnNameToSymbol.containsKey(part), "element %s doesn't exist in map %s", part, columnNameToSymbol); + result = columnNameToSymbol.get(part).toSymbolReference(); + } + else { + result = new DereferenceExpression(result, new Identifier(part)); + } + } + return result; + } + } + } + + // round trip used during the rewrite: expression msg_12.foo -> nested column msg.foo -> expression msg_12.foo + + private static class Rewriter + extends ExpressionRewriter<Void> + { + private final Map<Expression, Symbol> map; + + Rewriter(Map<Expression, Symbol> map) + { + this.map = map; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) + { + if (map.containsKey(node)) { + return map.get(node).toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + + @Override + public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) + { + if (map.containsKey(node)) { + return map.get(node).toSymbolReference(); + } + return super.rewriteSymbolReference(node, context, treeRewriter); + } + } +
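+ // Extracts the longest pure dereference chain of an expression, cut below the innermost subscript.
+ // Illustrative behavior (names are hypothetical): msg.foo -> msg.foo; msg.foo[1].bar -> msg.foo;
+ // max_by(x, y).field -> null, since the base is a call rather than a column path.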
+ public static Expression validDereferenceExpression(Expression expression) + { + SubscriptExpression[] shortestSubscriptExp = new SubscriptExpression[1]; + boolean[] valid = new boolean[1]; + valid[0] = true; + new DefaultExpressionTraversalVisitor<Void, Void>() + { + @Override + protected Void visitSubscriptExpression(SubscriptExpression node, Void context) + { + shortestSubscriptExp[0] = node; + process(node.getBase(), context); + return null; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + valid[0] &= (node.getBase() instanceof SymbolReference || node.getBase() instanceof DereferenceExpression || node.getBase() instanceof SubscriptExpression); + process(node.getBase(), context); + return null; + } + }.process(expression, null); + if (valid[0]) { + return shortestSubscriptExp[0] == null ? expression : shortestSubscriptExp[0].getBase(); + } + else { + return null; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java new file mode 100644 index 000000000000..5573be03da74 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java @@ -0,0 +1,369 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.prestosql.sql.planner.optimizations; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.parser.SqlParser; +import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.planner.plan.AggregationNode; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SimplePlanRewriter; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.UnnestNode; +import io.prestosql.sql.planner.plan.ValuesNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.ExpressionRewriter; +import io.prestosql.sql.tree.ExpressionTreeRewriter; +import io.prestosql.sql.tree.NodeRef; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static io.prestosql.sql.planner.optimizations.MergeNestedColumn.prefixExist; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; +
+/** + * Rewrites dereference expressions into dedicated symbols and pushes the resulting projections + * down towards the plan nodes that can produce them (TableScan, Values and Unnest). + */ +public class PushDownDereferenceExpression + implements PlanOptimizer +{ + private final Metadata metadata; + private final SqlParser sqlParser; + + public PushDownDereferenceExpression(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + @Override + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + Map<Expression, DereferenceInfo> expressionInfoMap = new HashMap<>(); + return SimplePlanRewriter.rewriteWith(new Optimizer(session, metadata, sqlParser, symbolAllocator, idAllocator, warningCollector), plan, expressionInfoMap); + } + + private static class DereferenceReplacer + extends ExpressionRewriter<Void> + { + private final Map<Expression, DereferenceInfo> map; + + DereferenceReplacer(Map<Expression, DereferenceInfo> map) + { + this.map = map; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) + { + if (map.containsKey(node) && map.get(node).isFromValidSource()) { + return map.get(node).getSymbol().toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + } +
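+ // The Optimizer carries the candidate dereferences down the plan in the rewrite context;
+ // DereferenceReplacer only substitutes the new symbol once a source below (TableScan, Values
+ // or Unnest) has materialized the dereference and marked it as coming from a valid source.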
+ private static class Optimizer + extends SimplePlanRewriter<Map<Expression, DereferenceInfo>> + { + private final Session session; + private final SqlParser sqlParser; + private final SymbolAllocator symbolAllocator; + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final WarningCollector warningCollector; + + private Optimizer(Session session, Metadata metadata, SqlParser sqlParser, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + this.session = requireNonNull(session, "session is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); + } + + @Override + public PlanNode visitAggregation(AggregationNode node, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + extractDereferenceInfos(node).forEach(expressionInfoMap::putIfAbsent); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + Map<Symbol, AggregationNode.Aggregation> aggregations = new HashMap<>(); + for (Map.Entry<Symbol, AggregationNode.Aggregation> symbolAggregationEntry : node.getAggregations().entrySet()) { + Symbol symbol = symbolAggregationEntry.getKey(); + AggregationNode.Aggregation oldAggregation = symbolAggregationEntry.getValue(); + AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation(ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), oldAggregation.getCall()), oldAggregation.getSignature(), oldAggregation.getMask()); + aggregations.put(symbol, newAggregation); + } + return new AggregationNode( + idAllocator.getNextId(), + child, + aggregations, + node.getGroupingSets(), + node.getPreGroupedSymbols(), + node.getStep(), + node.getHashSymbol(), + node.getGroupIdSymbol()); + } + + @Override + public PlanNode visitFilter(FilterNode node, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + extractDereferenceInfos(node).forEach(expressionInfoMap::putIfAbsent); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + Expression predicate = ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), node.getPredicate()); + return new FilterNode(idAllocator.getNextId(), child, predicate); + } +
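+ // Pass-through example for the Project rewrite below (symbols are hypothetical): if a parent
+ // needs msg.foo as msg_foo and this projection does not output it, an identity assignment
+ // msg_foo := msg_foo is appended so the value produced at the source still reaches the parent.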
+ @Override + public PlanNode visitProject(ProjectNode node, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + // parentDereferenceInfos is used to work out the pass-through symbols: only symbols that are still needed upstream are passed along + List<DereferenceInfo> parentDereferenceInfos = expressionInfoMap.entrySet().stream().map(Map.Entry::getValue).collect(Collectors.toList()); + Map<Expression, DereferenceInfo> newDereferences = extractDereferenceInfos(node); + newDereferences.forEach(expressionInfoMap::putIfAbsent); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + List<Symbol> passThroughSymbols = getUsedDereferenceInfo(node.getOutputSymbols(), parentDereferenceInfos).stream().filter(DereferenceInfo::isFromValidSource).map(DereferenceInfo::getSymbol).collect(Collectors.toList()); + + Assignments.Builder assignmentsBuilder = Assignments.builder(); + for (Map.Entry<Symbol, Expression> entry : node.getAssignments().entrySet()) { + assignmentsBuilder.put(entry.getKey(), ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), entry.getValue())); + } + assignmentsBuilder.putIdentities(passThroughSymbols); + ProjectNode newProjectNode = new ProjectNode(idAllocator.getNextId(), child, assignmentsBuilder.build()); + newDereferences.forEach(expressionInfoMap::remove); + return newProjectNode; + } + + @Override + public PlanNode visitTableScan(TableScanNode node, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + List<DereferenceInfo> usedDereferenceInfo = getUsedDereferenceInfo(node.getOutputSymbols(), expressionInfoMap.values()); + if (!usedDereferenceInfo.isEmpty()) { + usedDereferenceInfo.forEach(DereferenceInfo::markFromValidSource); + Map<Symbol, Expression> assignmentMap = usedDereferenceInfo.stream().collect(Collectors.toMap(DereferenceInfo::getSymbol, DereferenceInfo::getDereference)); + return new ProjectNode(idAllocator.getNextId(), node, Assignments.builder().putAll(assignmentMap).putIdentities(node.getOutputSymbols()).build()); + } + return node; + } + + @Override + public PlanNode visitValues(ValuesNode node, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + List<DereferenceInfo> usedDereferenceInfo = getUsedDereferenceInfo(node.getOutputSymbols(), expressionInfoMap.values()); + if (!usedDereferenceInfo.isEmpty()) { + usedDereferenceInfo.forEach(DereferenceInfo::markFromValidSource); + Map<Symbol, Expression> assignmentMap = usedDereferenceInfo.stream().collect(Collectors.toMap(DereferenceInfo::getSymbol, DereferenceInfo::getDereference)); + return new ProjectNode(idAllocator.getNextId(), node, Assignments.builder().putAll(assignmentMap).putIdentities(node.getOutputSymbols()).build()); + } + return node; + } + + @Override + public PlanNode visitJoin(JoinNode joinNode, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + extractDereferenceInfos(joinNode).forEach(expressionInfoMap::putIfAbsent); + + PlanNode leftNode = context.rewrite(joinNode.getLeft(), expressionInfoMap); + PlanNode rightNode = context.rewrite(joinNode.getRight(), expressionInfoMap); + + List<JoinNode.EquiJoinClause> equiJoinClauses = joinNode.getCriteria().stream() + .map(JoinNode.EquiJoinClause::toExpression) + .map(expr -> ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), expr)) + .map(this::getEquiJoinClause) + .collect(Collectors.toList()); + + Optional<Expression> joinFilter = joinNode.getFilter().map(expression -> ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), expression)); + + return new JoinNode( + joinNode.getId(), + joinNode.getType(), + leftNode, + rightNode, + equiJoinClauses, + ImmutableList.<Symbol>builder().addAll(leftNode.getOutputSymbols()).addAll(rightNode.getOutputSymbols()).build(), + joinFilter, + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType()); + } +
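+ // For UNNEST, dereferences over the unnested symbols are computed in a Project placed above
+ // the new UnnestNode, while dereferences over replicated symbols are passed through from the
+ // source as additional replicate symbols.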
+ @Override + public PlanNode visitUnnest(UnnestNode node, RewriteContext<Map<Expression, DereferenceInfo>> context) + { + Map<Expression, DereferenceInfo> expressionInfoMap = context.get(); + List<DereferenceInfo> parentDereferenceInfos = expressionInfoMap.entrySet().stream().map(Map.Entry::getValue).collect(Collectors.toList()); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + List<Symbol> passThroughSymbols = getUsedDereferenceInfo(child.getOutputSymbols(), parentDereferenceInfos).stream().filter(DereferenceInfo::isFromValidSource).map(DereferenceInfo::getSymbol).collect(Collectors.toList()); + UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), child, ImmutableList.<Symbol>builder().addAll(node.getReplicateSymbols()).addAll(passThroughSymbols).build(), node.getUnnestSymbols(), node.getOrdinalitySymbol()); + + List<Symbol> unnestSymbols = unnestNode.getUnnestSymbols().entrySet().stream().flatMap(entry -> entry.getValue().stream()).collect(Collectors.toList()); + List<DereferenceInfo> dereferenceExpressionInfos = getUsedDereferenceInfo(unnestSymbols, expressionInfoMap.values()); + if (!dereferenceExpressionInfos.isEmpty()) { + dereferenceExpressionInfos.forEach(DereferenceInfo::markFromValidSource); + Map<Symbol, Expression> assignmentMap = dereferenceExpressionInfos.stream().collect(Collectors.toMap(DereferenceInfo::getSymbol, DereferenceInfo::getDereference)); + return new ProjectNode(idAllocator.getNextId(), unnestNode, Assignments.builder().putAll(assignmentMap).putIdentities(unnestNode.getOutputSymbols()).build()); + } + return unnestNode; + } + + private List<DereferenceInfo> getUsedDereferenceInfo(List<Symbol> symbols, Collection<DereferenceInfo> dereferenceExpressionInfos) + { + Set<Symbol> symbolSet = symbols.stream().collect(Collectors.toSet()); + return dereferenceExpressionInfos.stream().filter(dereferenceExpressionInfo -> symbolSet.contains(dereferenceExpressionInfo.getBaseSymbol())).collect(Collectors.toList()); + } + + private JoinNode.EquiJoinClause getEquiJoinClause(Expression expression) + { + checkArgument(expression instanceof ComparisonExpression, "expression [%s] is not an equality comparison", expression); + ComparisonExpression comparisonExpression = (ComparisonExpression) expression; + return new JoinNode.EquiJoinClause(Symbol.from(comparisonExpression.getLeft()), Symbol.from(comparisonExpression.getRight())); + } + + private Type extractType(Expression expression) + { + Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), warningCollector); + return expressionTypes.get(NodeRef.of(expression)); + } + + private DereferenceInfo getDereferenceInfo(Expression expression) + { + Symbol symbol = symbolAllocator.newSymbol(expression, extractType(expression)); + Symbol base = Iterables.getOnlyElement(SymbolsExtractor.extractAll(expression)); + return new DereferenceInfo(expression, symbol, base); + } + + private List<DereferenceExpression> extractDereference(Expression expression) + { + ImmutableList.Builder<DereferenceExpression> builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor<Void, ImmutableList.Builder<DereferenceExpression>>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder<DereferenceExpression> context) + { + context.add(node); + return null; + } + }.process(expression, builder); + return builder.build(); + } + + private Map<Expression, DereferenceInfo> extractDereferenceInfos(PlanNode node) + { + Set<Expression> allExpressions = ExpressionExtractor.extractExpressionsNonRecursive(node).stream() + .flatMap(expression -> extractDereference(expression).stream()) + .map(MergeNestedColumn::validDereferenceExpression).filter(Objects::nonNull).collect(toImmutableSet()); +
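+ // Keep only the shortest candidate chains; a longer chain such as msg.foo.bar (hypothetical)
+ // is dropped here when msg.foo is also a candidate, and is later rewritten on top of msg.foo's
+ // replacement symbol by DereferenceReplacer's default rewrite of the base.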
+ return allExpressions.stream() + .filter(expression -> !prefixExist(expression, allExpressions)) + .filter(expression -> expression instanceof DereferenceExpression) + .distinct() + .map(this::getDereferenceInfo) + .collect(Collectors.toMap(DereferenceInfo::getDereference, Function.identity())); + } + } + + private static class DereferenceInfo + { + // e.g. for the dereference expression msg.foo[1].bar, baseSymbol is "msg" and symbol is the newly assigned symbol that replaces the whole expression + private final Expression dereferenceExpression; + private final Symbol symbol; + private final Symbol baseSymbol; + + // fromValidSource records whether the dereference expression comes from a TableScan, Values or Unnest. + // It stays false for plans like the following, so we won't rewrite: + // Project[expr_1 := "max_by"."field1"] + // - Aggregate[max_by := "max_by"("expr", "app_rating")] => [max_by:row(field0 varchar, field1 varchar)] + private boolean fromValidSource; + + public DereferenceInfo(Expression dereferenceExpression, Symbol symbol, Symbol baseSymbol) + { + this.dereferenceExpression = requireNonNull(dereferenceExpression); + this.symbol = requireNonNull(symbol); + this.baseSymbol = requireNonNull(baseSymbol); + this.fromValidSource = false; + } + + public Symbol getSymbol() + { + return symbol; + } + + public Symbol getBaseSymbol() + { + return baseSymbol; + } + + public Expression getDereference() + { + return dereferenceExpression; + } + + public boolean isFromValidSource() + { + return fromValidSource; + } + + public void markFromValidSource() + { + fromValidSource = true; + } + + @Override + public String toString() + { + return String.format("(%s, %s, %s)", dereferenceExpression, symbol, baseSymbol); + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java b/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java index 5f6a8e1bf6e7..375ca1711ac5 100644 --- a/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java +++ b/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java @@ -54,6 +54,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.spi.StandardErrorCode.ALREADY_EXISTS; import static java.util.Objects.requireNonNull; @@ -64,6 +65,16 @@ public class TestingMetadata private final ConcurrentMap<SchemaTableName, ConnectorTableMetadata> tables = new ConcurrentHashMap<>(); private final ConcurrentMap<SchemaTableName, String> views = new ConcurrentHashMap<>(); + protected ConcurrentMap<SchemaTableName, ConnectorTableMetadata> getTables() + { + return tables; + } + + protected ConcurrentMap<SchemaTableName, String> getViews() + { + return views; + } + @Override public List<String> listSchemaNames(ConnectorSession session) { @@ -290,7 +301,7 @@ public void clear() tables.clear(); } - private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) + protected static SchemaTableName getTableName(ConnectorTableHandle tableHandle) { requireNonNull(tableHandle, "tableHandle is null"); checkArgument(tableHandle instanceof TestingTableHandle, "tableHandle is not an instance of TestingTableHandle"); @@ -397,5 +408,11 @@ public int hashCode() { return Objects.hash(name, ordinalPosition, type); } + + @Override + public String toString() + { + return toStringHelper(this).add("name", name).add("position", ordinalPosition).add("type", type).toString(); + } } } diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java
b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index 5ce0beed4a7f..64b0740f9247 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.prestosql.Session; import io.prestosql.connector.ConnectorId; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.block.BlockEncodingSerde; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; @@ -167,6 +168,12 @@ public Map getColumnHandles(Session session, TableHandle t throw new UnsupportedOperationException(); } + @Override + public Map getNestedColumnHandles(Session session, TableHandle tableHandle, Collection dereferences) + { + throw new UnsupportedOperationException(); + } + @Override public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java new file mode 100644 index 000000000000..0cb916e114a5 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; + +public class TestDereferencePushDown + extends BasePlanTest +{ + private static final String VALUES = "(values ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))"; + + @Test + public void testPushDownDereferencesThroughJoin() + { + assertPlan(" with t1 as ( select * from " + VALUES + " as t (msg) ) select b.msg.x from t1 a, t1 b where a.msg.y = b.msg.y", + output(ImmutableList.of("x"), + join(INNER, ImmutableList.of(equiJoinClause("left_y", "right_y")), + anyTree( + project(ImmutableMap.of("left_y", expression("field.y")), + values("field")) + ), anyTree( + project(ImmutableMap.of("right_y", expression("field1.y"), "x", expression("field1.x")), + values("field1")))))); + } + + @Test + public void testPushDownDereferencesInCase() + { + // Test dereferences in then clause will not be eagerly evaluated. + String statement = "with t as (select * from (values cast(array[CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)),ROW(1, 2.0)] as array)) as t (arr) ) " + + "select case when cardinality(arr) > cast(0 as bigint) then arr[cast(1 as bigint)].x end from t"; + assertPlan(statement, + output(ImmutableList.of("x"), + project(ImmutableMap.of("x", expression("case when cardinality(field) > bigint '0' then field[bigint '1'].x end")), values("field")))); + } + + @Test + public void testPushDownDereferencesThroughFilter() + { + assertPlan(" with t1 as ( select * from " + VALUES + " as t (msg) ) select a.msg.y from t1 a join t1 b on a.msg.y = b.msg.y where a.msg.x > bigint '5'", + output(ImmutableList.of("left_y"), + join(INNER, ImmutableList.of(equiJoinClause("left_y", "right_y")), + anyTree( + project(ImmutableMap.of("left_y", expression("field.y")), + filter("field.x > bigint '5'", values("field"))) + ), anyTree( + project(ImmutableMap.of("right_y", expression("field1.y")), + values("field1")))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java new file mode 100644 index 000000000000..7d31e4320cd7 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java @@ -0,0 +1,264 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.NestedColumn; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.ColumnMetadata; +import io.prestosql.spi.connector.Connector; +import io.prestosql.spi.connector.ConnectorContext; +import io.prestosql.spi.connector.ConnectorFactory; +import io.prestosql.spi.connector.ConnectorHandleResolver; +import io.prestosql.spi.connector.ConnectorMetadata; +import io.prestosql.spi.connector.ConnectorPageSinkProvider; +import io.prestosql.spi.connector.ConnectorPageSourceProvider; +import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.connector.ConnectorSplitManager; +import io.prestosql.spi.connector.ConnectorTableHandle; +import io.prestosql.spi.connector.ConnectorTableLayout; +import io.prestosql.spi.connector.ConnectorTableLayoutHandle; +import io.prestosql.spi.connector.ConnectorTableLayoutResult; +import io.prestosql.spi.connector.ConnectorTableMetadata; +import io.prestosql.spi.connector.ConnectorTransactionHandle; +import io.prestosql.spi.connector.Constraint; +import io.prestosql.spi.connector.FixedPageSource; +import io.prestosql.spi.connector.SchemaTableName; +import io.prestosql.spi.transaction.IsolationLevel; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VarcharType; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import io.prestosql.testing.LocalQueryRunner; +import io.prestosql.testing.TestingHandleResolver; +import io.prestosql.testing.TestingMetadata; +import io.prestosql.testing.TestingPageSinkProvider; +import io.prestosql.testing.TestingSplitManager; +import org.testng.annotations.Test; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static io.prestosql.spi.type.RowType.field; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.prestosql.testing.TestingSession.testSessionBuilder; +import static java.util.Objects.requireNonNull; + +public class TestMergeNestedColumns + extends BasePlanTest +{ + private static final Type MSG_TYPE = RowType.from(ImmutableList.of(field("x", VarcharType.VARCHAR), field("y", VarcharType.VARCHAR))); + + public TestMergeNestedColumns() + { + super(TestMergeNestedColumns::createQueryRunner); + } + + @Test + public void testSelectDereference() + { + assertPlan("select foo.x, foo.y, bar.x, bar.y from nested_column_table", + output(ImmutableList.of("foo_x", "foo_y", "bar_x", "bar_y"), + project(ImmutableMap.of("foo_x", expression("foo_x"), "foo_y", expression("foo_y"), "bar_x", expression("bar.x"), "bar_y", expression("bar.y")), + tableScan("nested_column_table", ImmutableMap.of("foo_x", "foo.x", "foo_y", "foo.y", "bar", "bar"))))); + } + + @Test + public void testSelectDereferenceAndParentDoesNotFire() + { + assertPlan("select foo.x, foo.y, foo from nested_column_table", + output(ImmutableList.of("foo_x", "foo_y", "foo"), + 
project(ImmutableMap.of("foo_x", expression("foo.x"), "foo_y", expression("foo.y"), "foo", expression("foo")), + tableScan("nested_column_table", ImmutableMap.of("foo", "foo"))))); + } + + private static LocalQueryRunner createQueryRunner() + { + String schemaName = "test-schema"; + String catalogName = "test"; + TableInfo regularTable = new TableInfo(new SchemaTableName(schemaName, "regular_table"), + ImmutableList.of(new ColumnMetadata("dummy_column", VarcharType.VARCHAR)), ImmutableMap.of()); + + TableInfo nestedColumnTable = new TableInfo(new SchemaTableName(schemaName, "nested_column_table"), + ImmutableList.of(new ColumnMetadata("foo", MSG_TYPE), new ColumnMetadata("bar", MSG_TYPE)), ImmutableMap.of( + new NestedColumn(ImmutableList.of("foo", "x")), 0, + new NestedColumn(ImmutableList.of("foo", "y")), 0)); + + ImmutableList tableInfos = ImmutableList.of(regularTable, nestedColumnTable); + + LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder() + .setCatalog(catalogName) + .setSchema(schemaName) + .build()); + queryRunner.createCatalog(catalogName, new TestConnectorFactory(new TestMetadata(tableInfos)), ImmutableMap.of()); + return queryRunner; + } + + private static class TestConnectorFactory + implements ConnectorFactory + { + private final TestMetadata metadata; + + public TestConnectorFactory(TestMetadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public String getName() + { + return "test"; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new TestingHandleResolver(); + } + + @Override + public Connector create(String connectorId, Map config, ConnectorContext context) + { + return new TestConnector(metadata); + } + } + + private enum TransactionInstance + implements ConnectorTransactionHandle + { + INSTANCE + } + + private static class TestConnector + implements Connector + { + private final ConnectorMetadata metadata; + + private TestConnector(ConnectorMetadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return TransactionInstance.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new TestingSplitManager(ImmutableList.of()); + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return (transactionHandle, session, split, columns) -> new FixedPageSource(ImmutableList.of()); + } + + @Override + public ConnectorPageSinkProvider getPageSinkProvider() + { + return new TestingPageSinkProvider(); + } + } + + private static class TestMetadata + extends TestingMetadata + { + private final List tableInfos; + + TestMetadata(List tableInfos) + { + this.tableInfos = requireNonNull(tableInfos, "tableinfos is null"); + insertTables(); + } + + private void insertTables() + { + for (TableInfo tableInfo : tableInfos) { + getTables().put(tableInfo.getSchemaTableName(), tableInfo.getTableMetadata()); + } + } + + @Override + public Map getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection dereferences) + { + requireNonNull(tableHandle, "tableHandle is null"); + SchemaTableName tableName = getTableName(tableHandle); + return tableInfos.stream().filter(tableInfo -> 
tableInfo.getSchemaTableName().equals(tableName)).map(TableInfo::getNestedColumnHandle).findFirst().orElse(ImmutableMap.of()); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + return new ConnectorTableLayout(new ConnectorTableLayoutHandle() {}); + } + + @Override + public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + { + return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(new ConnectorTableLayoutHandle() {}), constraint.getSummary())); + } + } + + private static class TableInfo + { + private final SchemaTableName schemaTableName; + private final List columnMetadatas; + private final Map nestedColumns; + + public TableInfo(SchemaTableName schemaTableName, List columnMetadata, Map nestedColumns) + { + this.schemaTableName = schemaTableName; + this.columnMetadatas = columnMetadata; + this.nestedColumns = nestedColumns; + } + + SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + ConnectorTableMetadata getTableMetadata() + { + return new ConnectorTableMetadata(schemaTableName, columnMetadatas); + } + + Map getNestedColumnHandle() + { + return nestedColumns.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { + Preconditions.checkArgument(entry.getValue() >= 0 && entry.getValue() < columnMetadatas.size(), "index is not valid"); + NestedColumn nestedColumn = entry.getKey(); + return new TestingMetadata.TestingColumnHandle(nestedColumn.getName(), entry.getValue(), VarcharType.VARCHAR); + })); + } + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java index 95e2796c753d..6976918f623d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java @@ -13,6 +13,8 @@ */ package io.prestosql.sql.planner.assertions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import io.prestosql.Session; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; @@ -23,6 +25,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.TableScanNode; +import java.util.AbstractMap; import java.util.Map; import java.util.Optional; @@ -91,8 +94,8 @@ private Optional getAssignedSymbol(Map assignments private Optional getColumnHandle(TableHandle tableHandle, Session session, Metadata metadata) { - return metadata.getColumnHandles(session, tableHandle).entrySet() - .stream() + return Streams.concat(metadata.getColumnHandles(session, tableHandle).entrySet().stream(), + metadata.getNestedColumnHandles(session, tableHandle, ImmutableList.of()).entrySet().stream().map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey().getName(), entry.getValue()))) .filter(entry -> columnName.equals(entry.getKey())) .map(Map.Entry::getValue) .findFirst(); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java index 00fa3d9f9c04..bf60645ba5da 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java @@ -21,6 
+21,7 @@ import io.prestosql.sql.tree.CoalesceExpression; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.DecimalLiteral; +import io.prestosql.sql.tree.DereferenceExpression; import io.prestosql.sql.tree.DoubleLiteral; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; @@ -34,8 +35,10 @@ import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.NullLiteral; +import io.prestosql.sql.tree.SearchedCaseExpression; import io.prestosql.sql.tree.SimpleCaseExpression; import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SubscriptExpression; import io.prestosql.sql.tree.SymbolReference; import io.prestosql.sql.tree.TryExpression; import io.prestosql.sql.tree.WhenClause; @@ -97,6 +100,20 @@ protected Boolean visitTryExpression(TryExpression actual, Node expected) return process(actual.getInnerExpression(), ((TryExpression) expected).getInnerExpression()); } + @Override + protected Boolean visitDereferenceExpression(DereferenceExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof DereferenceExpression)) { + return false; + } + + DereferenceExpression expected = (DereferenceExpression) expectedExpression; + if (actual.getField().equals(expected.getField())) { + return process(actual.getBase(), expected.getBase()); + } + return false; + } + @Override protected Boolean visitCast(Cast actual, Node expectedExpression) { @@ -125,6 +142,18 @@ protected Boolean visitIsNullPredicate(IsNullPredicate actual, Node expectedExpr return process(actual.getValue(), expected.getValue()); } + @Override + protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SubscriptExpression)) { + return false; + } + + SubscriptExpression expected = (SubscriptExpression) expectedExpression; + + return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + } + @Override protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Node expectedExpression) { @@ -307,6 +336,17 @@ protected Boolean visitNotExpression(NotExpression actual, Node expected) return false; } + @Override + protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SearchedCaseExpression)) { + return false; + } + + SearchedCaseExpression expected = (SearchedCaseExpression) expectedExpression; + return process(actual.getDefaultValue(), expected.getDefaultValue()) && process(actual.getWhenClauses(), expected.getWhenClauses()); + } + @Override protected Boolean visitSymbolReference(SymbolReference actual, Node expected) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java index 56fd34a61736..c1d4874e0a76 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java @@ -304,6 +304,11 @@ public static PlanMatchPattern strictOutput(List outputs, PlanMatchPatte return output(outputs, source).withExactOutputs(outputs); } + public static PlanMatchPattern unnest(List replicateSymbols, Map> unnestSymbols, Optional ordinalitySymbol, PlanMatchPattern source) + { + return unnest(source).with(new UnnestMatcher(replicateSymbols, unnestSymbols, 
ordinalitySymbol)); + } + public static PlanMatchPattern project(PlanMatchPattern source) { return node(ProjectNode.class, source); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java new file mode 100644 index 000000000000..18afa93c8e80 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.assertions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.prestosql.Session; +import io.prestosql.cost.StatsProvider; +import io.prestosql.metadata.Metadata; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.UnnestNode; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkState; +import static io.prestosql.sql.planner.assertions.MatchResult.match; +import static java.util.Objects.requireNonNull; + +public class UnnestMatcher + implements Matcher +{ + private final List replicateSymbols; + private final Map> unnestSymbols; + private final Optional ordinalitySymbol; + + public UnnestMatcher(List replicateSymbols, Map> unnestSymbols, Optional ordinalitySymbol) + { + this.replicateSymbols = requireNonNull(replicateSymbols, "replicateSymbols is null"); + this.unnestSymbols = requireNonNull(unnestSymbols); + this.ordinalitySymbol = requireNonNull(ordinalitySymbol); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof UnnestNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + UnnestNode unnestNode = (UnnestNode) node; + + ImmutableList.Builder aliasBuilder = ImmutableList.builder().addAll(replicateSymbols).addAll(Iterables.concat(unnestSymbols.values())); + ordinalitySymbol.ifPresent(aliasBuilder::add); + List alias = aliasBuilder.build(); + + if (alias.size() != unnestNode.getOutputSymbols().size()) { + return MatchResult.NO_MATCH; + } + + SymbolAliases.Builder builder = SymbolAliases.builder().putAll(symbolAliases); + IntStream.range(0, alias.size()).forEach(i -> builder.put(alias.get(i), unnestNode.getOutputSymbols().get(i).toSymbolReference())); + return match(builder.build()); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java index 63fec8bb2689..55bd21a41ded 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java +++ 
+public class UnnestMatcher
+        implements Matcher
+{
+    private final List<String> replicateSymbols;
+    private final Map<String, List<String>> unnestSymbols;
+    private final Optional<String> ordinalitySymbol;
+
+    public UnnestMatcher(List<String> replicateSymbols, Map<String, List<String>> unnestSymbols, Optional<String> ordinalitySymbol)
+    {
+        this.replicateSymbols = requireNonNull(replicateSymbols, "replicateSymbols is null");
+        this.unnestSymbols = requireNonNull(unnestSymbols, "unnestSymbols is null");
+        this.ordinalitySymbol = requireNonNull(ordinalitySymbol, "ordinalitySymbol is null");
+    }
+
+    @Override
+    public boolean shapeMatches(PlanNode node)
+    {
+        return node instanceof UnnestNode;
+    }
+
+    @Override
+    public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
+    {
+        checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
+        UnnestNode unnestNode = (UnnestNode) node;
+
+        // expected output order: replicated symbols, unnested symbols, then the optional ordinality symbol
+        ImmutableList.Builder<String> aliasBuilder = ImmutableList.<String>builder()
+                .addAll(replicateSymbols)
+                .addAll(Iterables.concat(unnestSymbols.values()));
+        ordinalitySymbol.ifPresent(aliasBuilder::add);
+        List<String> aliases = aliasBuilder.build();
+
+        if (aliases.size() != unnestNode.getOutputSymbols().size()) {
+            return MatchResult.NO_MATCH;
+        }
+
+        SymbolAliases.Builder builder = SymbolAliases.builder().putAll(symbolAliases);
+        IntStream.range(0, aliases.size()).forEach(i -> builder.put(aliases.get(i), unnestNode.getOutputSymbols().get(i).toSymbolReference()));
+        return match(builder.build());
+    }
+}
diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java
index 63fec8bb2689..55bd21a41ded 100644
--- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java
+++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java
@@ -42,6 +42,12 @@ public class RuleTester
     private final Metadata metadata;
     private final Session session;
     private final LocalQueryRunner queryRunner;
+
+    public LocalQueryRunner getQueryRunner()
+    {
+        return queryRunner;
+    }
+
     private final TransactionManager transactionManager;
     private final SplitManager splitManager;
diff --git a/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java b/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java
index 59a4ba55a755..d267608bd182 100644
--- a/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java
+++ b/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java
@@ -13,6 +13,9 @@
  */
 package io.prestosql.parquet;
 
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.type.DecimalType;
 import io.prestosql.spi.type.Type;
 import org.apache.parquet.column.Encoding;
@@ -24,6 +27,7 @@
 import org.apache.parquet.io.ParquetDecodingException;
 import org.apache.parquet.io.PrimitiveColumnIO;
 import org.apache.parquet.schema.DecimalMetadata;
+import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.MessageType;
 
 import java.util.Arrays;
@@ -34,6 +38,7 @@
 import java.util.Optional;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.Iterators.getOnlyElement;
 import static org.apache.parquet.schema.OriginalType.DECIMAL;
 import static org.apache.parquet.schema.Type.Repetition.REPEATED;
 
@@ -185,14 +190,14 @@ public static ParquetEncoding getParquetEncoding(Encoding encoding)
         }
     }
 
-    public static org.apache.parquet.schema.Type getParquetTypeByName(String columnName, MessageType messageType)
+    public static org.apache.parquet.schema.Type getParquetTypeByName(String columnName, GroupType groupType)
     {
-        if (messageType.containsField(columnName)) {
-            return messageType.getType(columnName);
+        if (groupType.containsField(columnName)) {
+            return groupType.getType(columnName);
         }
         // parquet is case-sensitive, but hive is not. all hive columns get converted to lowercase
         // check for direct match above but if no match found, try case-insensitive match
-        for (org.apache.parquet.schema.Type type : messageType.getFields()) {
+        for (org.apache.parquet.schema.Type type : groupType.getFields()) {
             if (type.getName().equalsIgnoreCase(columnName)) {
                 return type;
             }
@@ -262,4 +267,42 @@ public static long getShortDecimalValue(byte[] bytes)
         }
         return value;
     }
+
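+    /**
+     * Prunes {@code baseType} down to the single path named by {@code nestedColumn},
+     * returning null if any path element is missing. For example, given the
+     * hypothetical schema "message doc { optional group msg { optional group foo {
+     * optional binary bar; } optional int32 other; } }", the path msg.foo.bar
+     * yields "optional group msg { optional group foo { optional binary bar; } }".
+     */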
+    public static org.apache.parquet.schema.Type getNestedColumnType(GroupType baseType, NestedColumn nestedColumn)
+    {
+        Preconditions.checkArgument(!nestedColumn.getNames().isEmpty(), "nested column path must not be empty");
+
+        ImmutableList.Builder<org.apache.parquet.schema.Type> typeBuilder = ImmutableList.builder();
+        org.apache.parquet.schema.Type parentType = baseType;
+
+        for (String field : nestedColumn.getNames()) {
+            org.apache.parquet.schema.Type childType = getParquetTypeByName(field, parentType.asGroupType());
+            if (childType == null) {
+                return null;
+            }
+            typeBuilder.add(childType);
+            parentType = childType;
+        }
+        List<org.apache.parquet.schema.Type> typeChain = typeBuilder.build();
+
+        if (typeChain.size() == 1) {
+            return getOnlyElement(typeChain.iterator());
+        }
+
+        // rebuild the chain bottom-up so each level keeps only the single child on the path;
+        // wrapping with GroupType (rather than MessageType) preserves each level's name and repetition
+        org.apache.parquet.schema.Type prunedType = typeChain.get(typeChain.size() - 1);
+        for (int i = typeChain.size() - 2; i >= 0; --i) {
+            GroupType groupType = typeChain.get(i).asGroupType();
+            prunedType = new GroupType(groupType.getRepetition(), groupType.getName(), ImmutableList.of(prunedType));
+        }
+        return prunedType;
+    }
 }
diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java b/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java
index d5401a8a1b5e..5a812441eafc 100644
--- a/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java
+++ b/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java
@@ -20,6 +20,7 @@
 import java.util.Optional;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Locale.ENGLISH;
 
 public class DereferenceExpression
         extends Expression
@@ -118,12 +119,13 @@ public boolean equals(Object o)
         }
         DereferenceExpression that = (DereferenceExpression) o;
         return Objects.equals(base, that.base) &&
-                Objects.equals(field, that.field);
+                // SQL identifiers are case-insensitive, so compare and hash the lower-cased field name
+                Objects.equals(field.getValue().toLowerCase(ENGLISH), that.field.getValue().toLowerCase(ENGLISH));
     }
 
     @Override
     public int hashCode()
     {
-        return Objects.hash(base, field);
+        return Objects.hash(base, field.getValue().toLowerCase(ENGLISH));
     }
 }
diff --git a/presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java b/presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java
new file mode 100644
index 000000000000..0806d903a8e0
--- /dev/null
+++ b/presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.spi;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import io.prestosql.spi.type.Type;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import static java.util.Collections.unmodifiableList;
+import static java.util.Objects.requireNonNull;
+
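+/**
+ * A path of field names addressing a value nested inside a top-level column.
+ * For example (names are illustrative), the dereference expression msg.foo.bar
+ * is represented as names ["msg", "foo", "bar"]: getBase() returns the base
+ * column "msg", getRest() returns the path ["foo", "bar"] within it, and
+ * getName() returns the joined form "msg.foo.bar".
+ */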
+public class NestedColumn
+{
+    private final List<String> names;
+
+    @JsonCreator
+    public NestedColumn(@JsonProperty("names") List<String> names)
+    {
+        requireNonNull(names, "names is null");
+        if (names.isEmpty()) {
+            throw new IllegalArgumentException("names is empty");
+        }
+        this.names = unmodifiableList(new ArrayList<>(names));
+    }
+
+    @JsonProperty
+    public List<String> getNames()
+    {
+        return names;
+    }
+
+    public String getBase()
+    {
+        return names.get(0);
+    }
+
+    public List<String> getRest()
+    {
+        if (names.size() < 2) {
+            throw new IllegalStateException("nested column " + getName() + " has no nested path");
+        }
+        return names.subList(1, names.size());
+    }
+
+    public String getName()
+    {
+        return String.join(".", names);
+    }
+
+    // placeholder: the concrete type is resolved by the connector and is not serialized
+    public Type getType()
+    {
+        return null;
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hash(names);
+    }
+
+    @Override
+    public boolean equals(Object obj)
+    {
+        if (this == obj) {
+            return true;
+        }
+
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+
+        NestedColumn other = (NestedColumn) obj;
+        return Objects.equals(this.names, other.names);
+    }
+
+    @Override
+    public String toString()
+    {
+        return "NestedColumn{name='" + getName() + "'}";
+    }
+}
diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java
index 49847968b98e..9a458df42641 100644
--- a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java
+++ b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java
@@ -14,6 +14,7 @@
 package io.prestosql.spi.connector;
 
 import io.airlift.slice.Slice;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.PrestoException;
 import io.prestosql.spi.predicate.TupleDomain;
 import io.prestosql.spi.security.GrantInfo;
@@ -25,6 +26,7 @@
 import io.prestosql.spi.statistics.TableStatisticsMetadata;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -556,4 +558,15 @@ default List<GrantInfo> listTablePrivileges(ConnectorSession session, SchemaTablePrefix prefix)
     {
         return emptyList();
     }
+
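+    /**
+     * Returns handles for the nested column paths this connector can read
+     * directly, keyed by the requested path. Dereferences that are absent from
+     * the result are not pushed down and are evaluated by the engine instead.
+     * The default implementation pushes down nothing.
+     */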
+    default Map<NestedColumn, ColumnHandle> getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<NestedColumn> dereferences)
+    {
+        return Collections.emptyMap();
+    }
 }
diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
index 4b7ee5bb2bc6..5220c875dee9 100644
--- a/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
+++ b/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
@@ -14,6 +14,7 @@
 package io.prestosql.spi.connector.classloader;
 
 import io.airlift.slice.Slice;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.classloader.ThreadContextClassLoader;
 import io.prestosql.spi.connector.ColumnHandle;
 import io.prestosql.spi.connector.ColumnMetadata;
@@ -548,4 +549,12 @@ public List<GrantInfo> listTablePrivileges(ConnectorSession session, SchemaTablePrefix prefix)
             return delegate.listTablePrivileges(session, prefix);
         }
     }
+
+    @Override
+    public Map<NestedColumn, ColumnHandle> getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<NestedColumn> dereferences)
+    {
+        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+            return delegate.getNestedColumnHandles(session, tableHandle, dereferences);
+        }
+    }
 }