diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java index a110a3d5f2ea..1d238cc056d4 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveColumnHandle.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.ColumnMetadata; import io.prestosql.spi.type.TypeManager; @@ -60,6 +61,7 @@ public enum ColumnType private final int hiveColumnIndex; private final ColumnType columnType; private final Optional comment; + private final Optional nestedColumn; @JsonCreator public HiveColumnHandle( @@ -68,7 +70,8 @@ public HiveColumnHandle( @JsonProperty("typeSignature") TypeSignature typeSignature, @JsonProperty("hiveColumnIndex") int hiveColumnIndex, @JsonProperty("columnType") ColumnType columnType, - @JsonProperty("comment") Optional comment) + @JsonProperty("comment") Optional comment, + @JsonProperty("nestedColumn") Optional nestedColumn) { this.name = requireNonNull(name, "name is null"); checkArgument(hiveColumnIndex >= 0 || columnType == PARTITION_KEY || columnType == SYNTHESIZED, "hiveColumnIndex is negative"); @@ -77,6 +80,7 @@ public HiveColumnHandle( this.typeName = requireNonNull(typeSignature, "type is null"); this.columnType = requireNonNull(columnType, "columnType is null"); this.comment = requireNonNull(comment, "comment is null"); + this.nestedColumn = requireNonNull(nestedColumn, "nestedColumn is null"); } @JsonProperty @@ -97,6 +101,12 @@ public int getHiveColumnIndex() return hiveColumnIndex; } + @JsonProperty + public Optional getNestedColumn() + { + return nestedColumn; + } + public boolean isPartitionKey() { return columnType == PARTITION_KEY; @@ -133,7 +143,7 @@ public ColumnType getColumnType() @Override public int hashCode() { - return Objects.hash(name, hiveColumnIndex, hiveType, columnType, comment); + return Objects.hash(name, hiveColumnIndex, hiveType, columnType, comment, nestedColumn); } @Override @@ -150,7 +160,8 @@ public boolean equals(Object obj) Objects.equals(this.hiveColumnIndex, other.hiveColumnIndex) && Objects.equals(this.hiveType, other.hiveType) && Objects.equals(this.columnType, other.columnType) && - Objects.equals(this.comment, other.comment); + Objects.equals(this.comment, other.comment) && + Objects.equals(this.nestedColumn, other.nestedColumn); } @Override @@ -167,12 +178,12 @@ public static HiveColumnHandle updateRowIdHandle() // plan-time support for row-by-row delete so that planning doesn't fail. This is why we need // rowid handle. Note that in Hive connector, rowid handle is not implemented beyond plan-time. 
-        return new HiveColumnHandle(UPDATE_ROW_ID_COLUMN_NAME, HIVE_LONG, BIGINT.getTypeSignature(), -1, SYNTHESIZED, Optional.empty());
+        return new HiveColumnHandle(UPDATE_ROW_ID_COLUMN_NAME, HIVE_LONG, BIGINT.getTypeSignature(), -1, SYNTHESIZED, Optional.empty(), Optional.empty());
     }

     public static HiveColumnHandle pathColumnHandle()
     {
-        return new HiveColumnHandle(PATH_COLUMN_NAME, PATH_HIVE_TYPE, PATH_TYPE_SIGNATURE, PATH_COLUMN_INDEX, SYNTHESIZED, Optional.empty());
+        return new HiveColumnHandle(PATH_COLUMN_NAME, PATH_HIVE_TYPE, PATH_TYPE_SIGNATURE, PATH_COLUMN_INDEX, SYNTHESIZED, Optional.empty(), Optional.empty());
     }

     /**
@@ -182,7 +193,7 @@ public static HiveColumnHandle pathColumnHandle()
      */
     public static HiveColumnHandle bucketColumnHandle()
     {
-        return new HiveColumnHandle(BUCKET_COLUMN_NAME, BUCKET_HIVE_TYPE, BUCKET_TYPE_SIGNATURE, BUCKET_COLUMN_INDEX, SYNTHESIZED, Optional.empty());
+        return new HiveColumnHandle(BUCKET_COLUMN_NAME, BUCKET_HIVE_TYPE, BUCKET_TYPE_SIGNATURE, BUCKET_COLUMN_INDEX, SYNTHESIZED, Optional.empty(), Optional.empty());
     }

     public static boolean isPathColumnHandle(HiveColumnHandle column)
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java
index e4cd12e12665..99d990b65e82 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java
@@ -15,6 +15,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Joiner;
+import com.google.common.base.Preconditions;
 import com.google.common.base.Splitter;
 import com.google.common.base.Suppliers;
 import com.google.common.base.Verify;
@@ -45,6 +46,7 @@
 import io.prestosql.plugin.hive.metastore.Table;
 import io.prestosql.plugin.hive.metastore.thrift.ThriftMetastoreUtil;
 import io.prestosql.plugin.hive.statistics.HiveStatisticsProvider;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.PrestoException;
 import io.prestosql.spi.StandardErrorCode;
 import io.prestosql.spi.block.Block;
@@ -176,6 +178,7 @@
 import static io.prestosql.plugin.hive.HiveUtil.decodeViewData;
 import static io.prestosql.plugin.hive.HiveUtil.encodeViewData;
 import static io.prestosql.plugin.hive.HiveUtil.getPartitionKeyColumnHandles;
+import static io.prestosql.plugin.hive.HiveUtil.getRegularColumnHandles;
 import static io.prestosql.plugin.hive.HiveUtil.hiveColumnHandles;
 import static io.prestosql.plugin.hive.HiveUtil.schemaTableName;
 import static io.prestosql.plugin.hive.HiveUtil.toPartitionValues;
@@ -633,6 +636,39 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle)
         return ((HiveColumnHandle) columnHandle).getColumnMetadata(typeManager);
     }

+    @Override
+    public Map<NestedColumn, ColumnHandle> getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<NestedColumn> nestedColumns)
+    {
+        if (!HiveSessionProperties.isStatisticsEnabled(session)) {
+            return ImmutableMap.of();
+        }
+
+        SchemaTableName tableName = schemaTableName(tableHandle);
+        Optional<Table> table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName());
+        if (!table.isPresent()) {
+            throw new TableNotFoundException(tableName);
+        }
+
+        // Only push down nested columns for Parquet tables for now
+        if (!extractHiveStorageFormat(table.get()).equals(HiveStorageFormat.PARQUET)) {
+            return ImmutableMap.of();
+        }
+
+        List<HiveColumnHandle> regularColumnHandles = getRegularColumnHandles(table.get());
+        Map<String, HiveColumnHandle> regularHiveColumnHandles = regularColumnHandles.stream().collect(Collectors.toMap(HiveColumnHandle::getName, identity()));
+        ImmutableMap.Builder<NestedColumn, ColumnHandle> columnHandles = ImmutableMap.builder();
+        for (NestedColumn nestedColumn : nestedColumns) {
+            HiveColumnHandle hiveColumnHandle = regularHiveColumnHandles.get(nestedColumn.getBase());
+            // check for null before dereferencing the handle; a missing base column simply is not pushed down
+            if (hiveColumnHandle != null) {
+                Optional<HiveType> childType = hiveColumnHandle.getHiveType().findChildType(nestedColumn);
+                Preconditions.checkArgument(childType.isPresent(), "%s doesn't exist in parent type %s", nestedColumn, hiveColumnHandle.getHiveType());
+                columnHandles.put(nestedColumn, new HiveColumnHandle(nestedColumn.getName(), childType.get(), childType.get().getTypeSignature(), hiveColumnHandle.getHiveColumnIndex(), hiveColumnHandle.getColumnType(), hiveColumnHandle.getComment(), Optional.of(nestedColumn)));
+            }
+        }
+        return columnHandles.build();
+    }
+
     @Override
     public void createSchema(ConnectorSession session, String schemaName, Map<String, Object> properties)
     {
@@ -2148,7 +2184,8 @@ else if (column.isHidden()) {
                     column.getType().getTypeSignature(),
                     ordinal,
                     columnType,
-                    Optional.ofNullable(column.getComment())));
+                    Optional.ofNullable(column.getComment()),
+                    Optional.empty()));
             ordinal++;
         }
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java
index cec4185d1e9c..3c2ad68af1f0 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HivePageSourceProvider.java
@@ -17,6 +17,7 @@
 import com.google.common.collect.ImmutableSet;
 import io.prestosql.plugin.hive.HdfsEnvironment.HdfsContext;
 import io.prestosql.plugin.hive.HiveSplit.BucketConversion;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.connector.ColumnHandle;
 import io.prestosql.spi.connector.ConnectorPageSource;
 import io.prestosql.spi.connector.ConnectorPageSourceProvider;
@@ -320,8 +321,13 @@ public static List<ColumnMapping> buildColumnMappings(
         for (HiveColumnHandle column : columns) {
             Optional<HiveType> coercionFrom = Optional.ofNullable(columnCoercions.get(column.getHiveColumnIndex()));
             if (column.getColumnType() == REGULAR) {
-                checkArgument(regularColumnIndices.add(column.getHiveColumnIndex()), "duplicate hiveColumnIndex in columns list");
-                columnMappings.add(regular(column, regularIndex, coercionFrom));
+                if (column.getNestedColumn().isPresent()) {
+                    columnMappings.add(regular(column, regularIndex, getHiveType(coercionFrom, column.getNestedColumn().get())));
+                }
+                else {
+                    checkArgument(regularColumnIndices.add(column.getHiveColumnIndex()), "duplicate hiveColumnIndex in columns list");
+                    columnMappings.add(regular(column, regularIndex, coercionFrom));
+                }
                 regularIndex++;
             }
             else {
@@ -345,6 +351,11 @@ public static List<ColumnMapping> buildColumnMappings(
         return columnMappings.build();
     }

+    private static Optional<HiveType> getHiveType(Optional<HiveType> baseType, NestedColumn nestedColumn)
+    {
+        return baseType.flatMap(type -> type.findChildType(nestedColumn));
+    }
+
     public static List<ColumnMapping> extractRegularAndInterimColumnMappings(List<ColumnMapping> columnMappings)
     {
         return columnMappings.stream()
@@ -366,7 +377,8 @@ public static List<HiveColumnHandle> toColumnHandles(List<ColumnMapping> regular
                         columnMapping.getCoercionFrom().get().getTypeSignature(),
                         columnHandle.getHiveColumnIndex(),
                         columnHandle.getColumnType(),
-                        Optional.empty());
+                        Optional.empty(),
+                        columnHandle.getNestedColumn());
             })
             .collect(toList());
     }
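Review note: the resolution above is easiest to follow with a concrete dereference chain. The sketch below is illustrative only — `NestedColumn` is new in this patch and its accessors (`getBase`/`getRest`/`getName`) are inferred from their call sites here, so the exact semantics are an assumption.

```java
// Hypothetical walk-through of getNestedColumnHandles for a query reading msg.user.id
// from a Parquet-backed Hive table. Assumed NestedColumn semantics (inferred from usage):
//   getBase() -> "msg"            // the top-level Hive column
//   getRest() -> ["user", "id"]   // the path below the base
//   getName() -> "msg.user.id"    // the synthetic column name
NestedColumn nestedColumn = new NestedColumn(ImmutableList.of("msg", "user", "id"));

// The connector resolves the chain against the table schema. For a base column
// msg of type struct<user:struct<id:bigint>>, the returned handle is synthetic:
// its HiveType is the leaf type (bigint), while hiveColumnIndex still points at
// the physical "msg" column so the reader knows where to start.
Map<NestedColumn, ColumnHandle> handles =
        metadata.getNestedColumnHandles(session, tableHandle, ImmutableList.of(nestedColumn));
```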
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java
index 87d85731018e..75f085811352 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveType.java
@@ -15,7 +15,9 @@
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonValue;
+import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.PrestoException;
 import io.prestosql.spi.type.NamedTypeSignature;
 import io.prestosql.spi.type.RowFieldName;
@@ -175,6 +177,22 @@ public static boolean isSupportedType(TypeInfo typeInfo)
         return false;
     }

+    public Optional<HiveType> findChildType(NestedColumn nestedColumn)
+    {
+        TypeInfo typeInfo = getTypeInfo();
+        for (String part : nestedColumn.getRest()) {
+            Preconditions.checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo %s is not a struct type", typeInfo);
+            StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
+            try {
+                typeInfo = structTypeInfo.getStructFieldTypeInfo(part);
+            }
+            catch (RuntimeException e) {
+                return Optional.empty();
+            }
+        }
+        return Optional.of(toHiveType(typeInfo));
+    }
+
     @JsonCreator
     public static HiveType valueOf(String hiveTypeName)
     {
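For reference, `findChildType` is just a guarded descent through Hive's `TypeInfo` tree. A minimal standalone sketch of the same walk, using the hive-serde `TypeInfo` utilities directly; unlike the patch, it returns empty rather than throwing when an intermediate node is not a struct:

```java
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;

public class FindChildTypeSketch
{
    // Walks a dereference path down a struct TypeInfo and returns the leaf type,
    // or empty if a step is missing or an intermediate node is not a struct.
    static Optional<TypeInfo> findChildType(TypeInfo root, List<String> path)
    {
        TypeInfo current = root;
        for (String part : path) {
            if (!(current instanceof StructTypeInfo)) {
                return Optional.empty();
            }
            try {
                current = ((StructTypeInfo) current).getStructFieldTypeInfo(part);
            }
            catch (RuntimeException e) {
                return Optional.empty(); // getStructFieldTypeInfo throws on unknown field names
            }
        }
        return Optional.of(current);
    }

    public static void main(String[] args)
    {
        TypeInfo msg = TypeInfoUtils.getTypeInfoFromTypeString("struct<user:struct<id:bigint,name:string>>");
        System.out.println(findChildType(msg, Arrays.asList("user", "id")));   // Optional[bigint]
        System.out.println(findChildType(msg, Arrays.asList("user", "age")));  // Optional.empty
    }
}
```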
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java
index 75c4292af0b1..1818e74c2342 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveUtil.java
@@ -825,7 +825,7 @@ public static List<HiveColumnHandle> getRegularColumnHandles(Table table)
         // ignore unsupported types rather than failing
         HiveType hiveType = field.getType();
         if (hiveType.isSupportedType()) {
-            columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), hiveColumnIndex, REGULAR, field.getComment()));
+            columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), hiveColumnIndex, REGULAR, field.getComment(), Optional.empty()));
         }
         hiveColumnIndex++;
     }
@@ -843,7 +843,7 @@ public static List<HiveColumnHandle> getPartitionKeyColumnHandles(Table table)
         if (!hiveType.isSupportedType()) {
             throw new PrestoException(NOT_SUPPORTED, format("Unsupported Hive type %s found in partition keys of table %s.%s", hiveType, table.getDatabaseName(), table.getTableName()));
         }
-        columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), -1, PARTITION_KEY, field.getComment()));
+        columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), -1, PARTITION_KEY, field.getComment(), Optional.empty()));
     }
     return columns.build();
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
index 050e50ae5b79..954c6b645955 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
@@ -260,7 +260,7 @@ private static List<HiveColumnHandle> getPhysicalHiveColumnHandles(List<HiveColumnHandle> columns,
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSource.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSource.java
     private final List<String> columnNames;
     private final List<Type> types;
     private final List<Optional<Field>> fields;
+    private final List<Optional<NestedColumn>> nestedColumns;

     private final Block[] constantBlocks;
     private final int[] hiveColumnIndexes;
@@ -86,6 +92,7 @@ public ParquetPageSource(
         int size = columns.size();
         this.constantBlocks = new Block[size];
         this.hiveColumnIndexes = new int[size];
+        this.nestedColumns = new ArrayList<>(size);

         ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
         ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
@@ -99,15 +106,22 @@ public ParquetPageSource(
             namesBuilder.add(name);
             typesBuilder.add(type);
+            nestedColumns.add(column.getNestedColumn());
             hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();

-            if (getParquetType(column, fileSchema, useParquetColumnNames) == null) {
+            if (getColumnType(column, fileSchema, useParquetColumnNames) == null) {
                 constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_VECTOR_LENGTH);
                 fieldsBuilder.add(Optional.empty());
             }
             else {
-                String columnName = useParquetColumnNames ? name : fileSchema.getFields().get(column.getHiveColumnIndex()).getName();
-                fieldsBuilder.add(constructField(type, lookupColumnByName(messageColumnIO, columnName)));
+                if (column.getNestedColumn().isPresent()) {
+                    NestedColumn nestedColumn = column.getNestedColumn().get();
+                    fieldsBuilder.add(constructField(getNestedStructType(nestedColumn, type), lookupColumnByName(messageColumnIO, nestedColumn.getBase())));
+                }
+                else {
+                    String columnName = useParquetColumnNames ? name : fileSchema.getFields().get(column.getHiveColumnIndex()).getName();
+                    fieldsBuilder.add(constructField(type, lookupColumnByName(messageColumnIO, columnName)));
+                }
             }
         }
         types = typesBuilder.build();
@@ -157,19 +171,17 @@ public Page getNextPage()
                 blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize);
             }
             else {
-                Type type = types.get(fieldId);
                 Optional<Field> field = fields.get(fieldId);
-                int fieldIndex;
-                if (useParquetColumnNames) {
-                    fieldIndex = getFieldIndex(fileSchema, columnNames.get(fieldId));
-                }
-                else {
-                    fieldIndex = hiveColumnIndexes[fieldId];
-                }
-                if (fieldIndex != -1 && field.isPresent()) {
-                    blocks[fieldId] = new LazyBlock(batchSize, new ParquetBlockLoader(field.get()));
+                if (field.isPresent()) {
+                    if (nestedColumns.get(fieldId).isPresent()) {
+                        blocks[fieldId] = new LazyBlock(batchSize, new NestedColumnParquetBlockLoader(field.get(), types.get(fieldId)));
+                    }
+                    else {
+                        blocks[fieldId] = new LazyBlock(batchSize, new ParquetBlockLoader(field.get()));
+                    }
                 }
                 else {
+                    Type type = types.get(fieldId);
                     blocks[fieldId] = RunLengthEncodedBlock.create(type, null, batchSize);
                 }
             }
@@ -250,4 +262,93 @@ public final void load(LazyBlock lazyBlock)
             loaded = true;
         }
     }
+
+    private final class NestedColumnParquetBlockLoader
+            implements LazyBlockLoader<LazyBlock>
+    {
+        private final int expectedBatchId = batchId;
+        private final Field field;
+        private final Type type;
+        private final int level;
+        private boolean loaded;
+
+        // field is the group field rooted at the base column
+        public NestedColumnParquetBlockLoader(Field field, Type type)
+        {
+            this.field = requireNonNull(field, "field is null");
+            this.type = requireNonNull(type, "type is null");
+            this.level = getLevel(field.getType(), type);
+        }
+
+        int getLevel(Type rootType, Type leafType)
+        {
+            int level = 0;
+            Type currentType = rootType;
+            while (!currentType.equals(leafType)) {
+                currentType = currentType.getTypeParameters().get(0);
+                ++level;
+            }
+            return level;
+        }
+
+        @Override
+        public final void load(LazyBlock lazyBlock)
+        {
+            if (loaded) {
+                return;
+            }
+
+            checkState(batchId == expectedBatchId);
+
+            try {
+                Block block = parquetReader.readBlock(field);
+
+                int size = block.getPositionCount();
+                boolean[] isNulls = new boolean[size];
+
+                for (int currentLevel = 0; currentLevel < level; ++currentLevel) {
+                    ColumnarRow rowBlock = ColumnarRow.toColumnarRow(block);
+                    int index = 0;
+                    for (int j = 0; j < size; ++j) {
+                        if (!isNulls[j]) {
+                            isNulls[j] = rowBlock.isNull(index);
+                            ++index;
+                        }
+                    }
+                    block = rowBlock.getField(0);
+                }
+
+                BlockBuilder blockBuilder = type.createBlockBuilder(null, size);
+                int currentPosition = 0;
+                for (int i = 0; i < size; ++i) {
+                    if (isNulls[i]) {
+                        blockBuilder.appendNull();
+                    }
+                    else {
+                        Preconditions.checkArgument(currentPosition < block.getPositionCount(), "current position cannot exceed total position count");
+                        type.appendTo(block, currentPosition, blockBuilder);
+                        currentPosition++;
+                    }
+                }
+                lazyBlock.setBlock(blockBuilder.build());
+            }
+            catch (ParquetCorruptionException e) {
+                throw new PrestoException(HIVE_BAD_DATA, e);
+            }
+            catch (IOException e) {
+                throw new PrestoException(HIVE_CURSOR_ERROR, e);
+            }
+            loaded = true;
+        }
+    }
+
+    private Type getNestedStructType(NestedColumn nestedColumn, Type leafType)
+    {
+        Type type = leafType;
+        List<String> names = nestedColumn.getRest();
+        for (int i = names.size() - 1; i >= 0; --i) {
+            type = RowType.from(ImmutableList.of(RowType.field(names.get(i), type)));
+        }
+        return type;
+    }
 }
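Review note: the loader's unwrapping step is easier to see with the type that `getNestedStructType` actually builds. A minimal sketch under the same assumptions as above (path below the base is `["user", "id"]`, leaf type `bigint`):

```java
import com.google.common.collect.ImmutableList;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.Type;

import java.util.List;

import static io.prestosql.spi.type.BigintType.BIGINT;

public class NestedStructTypeSketch
{
    // Mirrors getNestedStructType: wrap the leaf type in single-field row types,
    // innermost path element first.
    static Type nestedStructType(List<String> rest, Type leafType)
    {
        Type type = leafType;
        for (int i = rest.size() - 1; i >= 0; --i) {
            type = RowType.from(ImmutableList.of(RowType.field(rest.get(i), type)));
        }
        return type;
    }

    public static void main(String[] args)
    {
        // For msg.user.id the reader is handed the pruned "msg" group, typed as
        // row(user row(id bigint)). NestedColumnParquetBlockLoader.getLevel then
        // counts 2 row layers to strip with ColumnarRow before the bigint values
        // and per-level nulls are stitched into the flat output block.
        Type type = nestedStructType(ImmutableList.of("user", "id"), BIGINT);
        System.out.println(type.getTypeSignature()); // the row(user row(id bigint)) signature
    }
}
```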
diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java
index 27c854e92e80..eeefc4a98d8e 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/parquet/ParquetPageSourceFactory.java
@@ -63,6 +63,7 @@
 import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
 import static io.prestosql.parquet.ParquetTypeUtils.getColumnIO;
 import static io.prestosql.parquet.ParquetTypeUtils.getDescriptors;
+import static io.prestosql.parquet.ParquetTypeUtils.getNestedColumnType;
 import static io.prestosql.parquet.ParquetTypeUtils.getParquetTypeByName;
 import static io.prestosql.parquet.predicate.PredicateUtils.buildPredicate;
 import static io.prestosql.parquet.predicate.PredicateUtils.predicateMatches;
@@ -77,7 +78,6 @@
 import static io.prestosql.plugin.hive.parquet.HdfsParquetDataSource.buildHdfsParquetDataSource;
 import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
-import static java.util.stream.Collectors.toList;
 import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.PRIMITIVE;

 public class ParquetPageSourceFactory
@@ -163,13 +163,14 @@ public static ParquetPageSource createParquetPageSource(
             MessageType fileSchema = fileMetaData.getSchema();
             dataSource = buildHdfsParquetDataSource(inputStream, path, fileSize, stats);

-            List<org.apache.parquet.schema.Type> fields = columns.stream()
+            Optional<MessageType> optionalRequestedSchema = columns.stream()
                     .filter(column -> column.getColumnType() == REGULAR)
-                    .map(column -> getParquetType(column, fileSchema, useParquetColumnNames))
+                    .map(column -> getColumnType(column, fileSchema, useParquetColumnNames))
                     .filter(Objects::nonNull)
-                    .collect(toList());
+                    .map(type -> new MessageType(fileSchema.getName(), type))
+                    .reduce(MessageType::union);

-            MessageType requestedSchema = new MessageType(fileSchema.getName(), fields);
+            MessageType requestedSchema = optionalRequestedSchema.orElseGet(() -> new MessageType(fileSchema.getName(), ImmutableList.of()));

             ImmutableList.Builder<BlockMetaData> footerBlocks = ImmutableList.builder();
             for (BlockMetaData block : parquetMetadata.getBlocks()) {
@@ -266,4 +267,12 @@ public static org.apache.parquet.schema.Type getParquetType(HiveColumnHandle col
         }
         return null;
     }
+
+    public static org.apache.parquet.schema.Type getColumnType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames)
+    {
+        if (useParquetColumnNames && column.getNestedColumn().isPresent()) {
+            return getNestedColumnType(messageType, column.getNestedColumn().get());
+        }
+        return getParquetType(column, messageType, useParquetColumnNames);
+    }
 }
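Review note on the requested-schema construction above: each surviving column is wrapped in its own single-column `MessageType` and the results are merged with `MessageType::union`, so overlapping prefixes of different dereference chains collapse into one group. A small self-contained illustration (the schema name and field types here are invented for the example):

```java
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;

import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;

public class RequestedSchemaUnionSketch
{
    public static void main(String[] args)
    {
        // Pruned schema for a reader of msg.user.id ...
        MessageType idOnly = new MessageType("hive_schema",
                new GroupType(OPTIONAL, "msg",
                        new GroupType(OPTIONAL, "user",
                                new PrimitiveType(OPTIONAL, INT64, "id"))));

        // ... and for a reader of msg.user.name
        MessageType nameOnly = new MessageType("hive_schema",
                new GroupType(OPTIONAL, "msg",
                        new GroupType(OPTIONAL, "user",
                                new PrimitiveType(OPTIONAL, BINARY, "name"))));

        // union merges the shared msg.user prefix into msg.user.{id, name}
        MessageType requested = idOnly.union(nameOnly);
        System.out.println(requested);
    }
}
```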
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java
index 2beeb7022ed7..6bac1e746886 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHive.java
@@ -624,11 +624,11 @@ protected void setupHive(String databaseName, String timeZoneId)
                 Optional.empty(),
                 Optional.empty());

-        dsColumn = new HiveColumnHandle("ds", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty());
-        fileFormatColumn = new HiveColumnHandle("file_format", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty());
-        dummyColumn = new HiveColumnHandle("dummy", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty());
-        intColumn = new HiveColumnHandle("t_int", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty());
-        invalidColumnHandle = new HiveColumnHandle(INVALID_COLUMN, HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 0, REGULAR, Optional.empty());
+        dsColumn = new HiveColumnHandle("ds", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty(), Optional.empty());
+        fileFormatColumn = new HiveColumnHandle("file_format", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), -1, PARTITION_KEY, Optional.empty(), Optional.empty());
+        dummyColumn = new HiveColumnHandle("dummy", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty(), Optional.empty());
+        intColumn = new HiveColumnHandle("t_int", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), -1, PARTITION_KEY, Optional.empty(), Optional.empty());
+        invalidColumnHandle = new HiveColumnHandle(INVALID_COLUMN, HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 0, REGULAR, Optional.empty(), Optional.empty());

         List<HiveColumnHandle> partitionColumns = ImmutableList.of(dsColumn, fileFormatColumn, dummyColumn);
         List<HivePartition> partitions = ImmutableList.<HivePartition>builder()
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java
index d80669152a75..b80072cbe001 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/AbstractTestHiveFileFormats.java
@@ -483,7 +483,7 @@ protected List<HiveColumnHandle> getColumnHandles(List<TestColumn> testColumns)
             int columnIndex = testColumn.isPartitionKey() ? -1 : nextHiveColumnIndex++;
             HiveType hiveType = HiveType.valueOf(testColumn.getObjectInspector().getTypeName());
-            columns.add(new HiveColumnHandle(testColumn.getName(), hiveType, hiveType.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty()));
+            columns.add(new HiveColumnHandle(testColumn.getName(), hiveType, hiveType.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty(), Optional.empty()));
         }
         return columns;
     }
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java
index 34242455824f..f9d33d0b5d4c 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestBackgroundHiveSplitLoader.java
@@ -100,7 +100,7 @@ public class TestBackgroundHiveSplitLoader
     private static final List<Column> PARTITION_COLUMNS = ImmutableList.of(
             new Column("partitionColumn", HIVE_INT, Optional.empty()));
     private static final List<HiveColumnHandle> BUCKET_COLUMN_HANDLES = ImmutableList.of(
-            new HiveColumnHandle("col1", HIVE_INT, INTEGER.getTypeSignature(), 0, ColumnType.REGULAR, Optional.empty()));
+            new HiveColumnHandle("col1", HIVE_INT, INTEGER.getTypeSignature(), 0, ColumnType.REGULAR, Optional.empty(), Optional.empty()));

     private static final Optional<HiveBucketProperty> BUCKET_PROPERTY = Optional.of(
             new HiveBucketProperty(ImmutableList.of("col1"), BUCKET_COUNT, ImmutableList.of()));
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java
index f7f097d422dc..867bef1187f9 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveColumnHandle.java
@@ -38,14 +38,14 @@ public void testHiddenColumn()
     @Test
     public void testRegularColumn()
     {
-        HiveColumnHandle expectedPartitionColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty());
+        HiveColumnHandle expectedPartitionColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty(), Optional.empty());
         testRoundTrip(expectedPartitionColumn);
     }

     @Test
     public void testPartitionKeyColumn()
     {
-        HiveColumnHandle expectedRegularColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, REGULAR, Optional.empty());
+        HiveColumnHandle expectedRegularColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, REGULAR, Optional.empty(), Optional.empty());
         testRoundTrip(expectedRegularColumn);
     }
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java
index 8d7c515868fd..c3a9fb906224 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java
@@ -34,6 +34,7 @@ public class TestHiveMetadata
             TypeSignature.parseTypeSignature("varchar"),
             0,
             HiveColumnHandle.ColumnType.PARTITION_KEY,
+            Optional.empty(),
             Optional.empty());

     @Test(timeOut = 5000)
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java
index 834bfb4c2d0d..016355b26096 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHivePageSink.java
@@ -283,7 +283,7 @@ private static List<HiveColumnHandle> getColumnHandles()
         for (int i = 0; i < columns.size(); i++) {
             LineItemColumn column = columns.get(i);
             HiveType hiveType =
getHiveType(column.getType()); - handles.add(new HiveColumnHandle(column.getColumnName(), hiveType, hiveType.getTypeSignature(), i, REGULAR, Optional.empty())); + handles.add(new HiveColumnHandle(column.getColumnName(), hiveType, hiveType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } return handles.build(); } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java index fa01cdc3ffa7..891e23570104 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveSplit.java @@ -61,7 +61,7 @@ public void testJsonRoundTrip() Optional.of(new HiveSplit.BucketConversion( 32, 16, - ImmutableList.of(new HiveColumnHandle("col", HIVE_LONG, BIGINT.getTypeSignature(), 5, ColumnType.REGULAR, Optional.of("comment"))))), + ImmutableList.of(new HiveColumnHandle("col", HIVE_LONG, BIGINT.getTypeSignature(), 5, ColumnType.REGULAR, Optional.of("comment"), Optional.empty())))), false); String json = codec.toJson(expected); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java index a38b0392024d..94987d922bc5 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestIonSqlQueryBuilder.java @@ -56,9 +56,9 @@ public void testBuildSQL() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("n_nationkey", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("n_name", HIVE_STRING, parseTypeSignature(VARCHAR), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("n_regionkey", HIVE_INT, parseTypeSignature(INTEGER), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("n_nationkey", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("n_name", HIVE_STRING, parseTypeSignature(VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("n_regionkey", HIVE_INT, parseTypeSignature(INTEGER), 2, REGULAR, Optional.empty(), Optional.empty())); assertEquals("SELECT s._1, s._2, s._3 FROM S3Object s", queryBuilder.buildSql(columns, TupleDomain.all())); @@ -81,9 +81,9 @@ public void testDecimalColumns() TypeManager typeManager = new TypeRegistry(); IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(typeManager); List columns = ImmutableList.of( - new HiveColumnHandle("quantity", HiveType.valueOf("decimal(20,0)"), parseTypeSignature(DECIMAL), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("extendedprice", HiveType.valueOf("decimal(20,2)"), parseTypeSignature(DECIMAL), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("discount", HiveType.valueOf("decimal(10,2)"), parseTypeSignature(DECIMAL), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("quantity", HiveType.valueOf("decimal(20,0)"), parseTypeSignature(DECIMAL), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("extendedprice", HiveType.valueOf("decimal(20,2)"), parseTypeSignature(DECIMAL), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("discount", HiveType.valueOf("decimal(10,2)"), parseTypeSignature(DECIMAL), 2, REGULAR, Optional.empty(), Optional.empty())); DecimalType decimalType = 
DecimalType.createDecimalType(10, 2); TupleDomain tupleDomain = withColumnDomains( ImmutableMap.of( @@ -101,8 +101,8 @@ public void testDateColumn() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("t1", HIVE_TIMESTAMP, parseTypeSignature(TIMESTAMP), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("t2", HIVE_DATE, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty())); + new HiveColumnHandle("t1", HIVE_TIMESTAMP, parseTypeSignature(TIMESTAMP), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("t2", HIVE_DATE, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty(), Optional.empty())); TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of( columns.get(1), Domain.create(SortedRangeSet.copyOf(DATE, ImmutableList.of(Range.equal(DATE, (long) DateTimeUtils.parseDate("2001-08-22")))), false))); @@ -114,9 +114,9 @@ public void testNotPushDoublePredicates() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("extendedprice", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("discount", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("extendedprice", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("discount", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 2, REGULAR, Optional.empty(), Optional.empty())); TupleDomain tupleDomain = withColumnDomains( ImmutableMap.of( columns.get(0), Domain.create(ofRanges(Range.lessThan(BIGINT, 50L)), false), diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java index bdb3ee55513b..50963432674b 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestJsonHiveHandles.java @@ -77,7 +77,7 @@ public void testTableHandleDeserialize() public void testColumnHandleSerialize() throws Exception { - HiveColumnHandle columnHandle = new HiveColumnHandle("column", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), -1, PARTITION_KEY, Optional.of("comment")); + HiveColumnHandle columnHandle = new HiveColumnHandle("column", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), -1, PARTITION_KEY, Optional.of("comment"), Optional.empty()); assertTrue(objectMapper.canSerialize(HiveColumnHandle.class)); String json = objectMapper.writeValueAsString(columnHandle); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java index dbf49cf79fac..80e211a3538e 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestOrcPageSourceMemoryTracking.java @@ -445,8 +445,7 @@ public TestPreparer(String tempFilePath, List testColumns, int numRo ObjectInspector inspector = 
testColumn.getObjectInspector(); HiveType hiveType = HiveType.valueOf(inspector.getTypeName()); Type type = hiveType.getType(TYPE_MANAGER); - - columnsBuilder.add(new HiveColumnHandle(testColumn.getName(), hiveType, type.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty())); + columnsBuilder.add(new HiveColumnHandle(testColumn.getName(), hiveType, type.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty(), Optional.empty())); typesBuilder.add(type); } columns = columnsBuilder.build(); diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java index 02f6b2d83064..57d5f539dbfd 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestS3SelectRecordCursor.java @@ -15,21 +15,28 @@ import com.google.common.collect.ImmutableList; import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.TestingTypeManager; +import io.prestosql.spi.type.TypeManager; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.mapred.RecordReader; +import org.joda.time.DateTimeZone; import org.testng.annotations.Test; import java.util.Optional; import java.util.Properties; import java.util.stream.Stream; +import static io.prestosql.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static io.prestosql.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; import static io.prestosql.plugin.hive.HiveType.HIVE_INT; import static io.prestosql.plugin.hive.HiveType.HIVE_STRING; import static io.prestosql.plugin.hive.S3SelectRecordCursor.updateSplitSchema; import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; import static java.util.stream.Collectors.joining; import static org.apache.hadoop.hive.serde.serdeConstants.LIST_COLUMNS; import static org.apache.hadoop.hive.serde.serdeConstants.LIST_COLUMN_TYPES; @@ -41,11 +48,42 @@ public class TestS3SelectRecordCursor { private static final String LAZY_SERDE_CLASS_NAME = LazySimpleSerDe.class.getName(); - private static final HiveColumnHandle ARTICLE_COLUMN = new HiveColumnHandle("article", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle AUTHOR_COLUMN = new HiveColumnHandle("author", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle DATE_ARTICLE_COLUMN = new HiveColumnHandle("date_pub", HIVE_INT, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle QUANTITY_COLUMN = new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), 1, REGULAR, Optional.empty()); + private static final HiveColumnHandle ARTICLE_COLUMN = new HiveColumnHandle("article", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle AUTHOR_COLUMN = new HiveColumnHandle("author", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()); + private 
static final HiveColumnHandle DATE_ARTICLE_COLUMN = new HiveColumnHandle("date_pub", HIVE_INT, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle QUANTITY_COLUMN = new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), 1, REGULAR, Optional.empty(), Optional.empty()); private static final HiveColumnHandle[] DEFAULT_TEST_COLUMNS = {ARTICLE_COLUMN, AUTHOR_COLUMN, DATE_ARTICLE_COLUMN, QUANTITY_COLUMN}; + private static final HiveColumnHandle MOCK_HIVE_COLUMN_HANDLE = new HiveColumnHandle("mockName", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty(), Optional.empty()); + private static final TypeManager MOCK_TYPE_MANAGER = new TestingTypeManager(); + private static final Path MOCK_PATH = new Path("mockPath"); + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "splitSchema is null") + public void shouldFailOnNullSplitSchema() + { + new S3SelectRecordCursor( + new Configuration(), + MOCK_PATH, + MOCK_RECORD_READER, + 100L, + null, + singletonList(MOCK_HIVE_COLUMN_HANDLE), + DateTimeZone.UTC, + MOCK_TYPE_MANAGER); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "columns is null") + public void shouldFailOnNullColumns() + { + new S3SelectRecordCursor( + new Configuration(), + MOCK_PATH, + MOCK_RECORD_READER, + 100L, + new Properties(), + null, + DateTimeZone.UTC, + MOCK_TYPE_MANAGER); + } @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Invalid Thrift DDL struct article \\{ \\}") public void shouldThrowIllegalArgumentExceptionWhenSerialDDLHasNoColumns() diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java index 3d58b4e514ff..6e2294ce7dff 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/benchmark/FileFormat.java @@ -13,6 +13,7 @@ */ package io.prestosql.plugin.hive.benchmark; +import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import io.airlift.slice.OutputStreamSliceOutput; import io.prestosql.orc.OrcWriter; @@ -42,6 +43,7 @@ import io.prestosql.rcfile.RcFileWriter; import io.prestosql.rcfile.binary.BinaryRcFileEncoding; import io.prestosql.rcfile.text.TextRcFileEncoding; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.Page; import io.prestosql.spi.connector.ConnectorPageSource; import io.prestosql.spi.connector.ConnectorSession; @@ -293,7 +295,7 @@ private static ConnectorPageSource createPageSource( for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); Type columnType = columnTypes.get(i); - columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty())); + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } RecordCursor recordCursor = cursorProvider @@ -327,7 +329,15 @@ private static ConnectorPageSource createPageSource( for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); Type columnType = columnTypes.get(i); - columnHandles.add(new HiveColumnHandle(columnName, 
toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty())); + + if (columnName.contains(".")) { + Splitter splitter = Splitter.on('.').trimResults().omitEmptyStrings(); + List names = splitter.splitToList(columnName); + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.of(new NestedColumn(names)))); + } + else { + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); + } } return pageSourceFactory diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java index b01d7c30b106..8305f9c38d20 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/AbstractTestParquetReader.java @@ -18,9 +18,11 @@ import com.google.common.collect.ContiguousSet; import com.google.common.collect.DiscreteDomain; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import com.google.common.collect.Range; import com.google.common.primitives.Shorts; import io.airlift.units.DataSize; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.RowType; import io.prestosql.spi.type.SqlDate; @@ -533,6 +535,33 @@ public void testNestedStructs() tester.testRoundTrip(objectInspector, values, values, type); } + @Test + public void testSingleNestedColumn() + throws Exception + { + int nestingLevel = ThreadLocalRandom.current().nextInt(1, 15); + Optional> structFieldNames = Optional.of(singletonList("structField")); + + Iterable values = createTestStructs(limit(cycle(asList(1, null, 3, null, 5, null, 7, null, null, null, 11, null, 13)), 3_210)); + Iterable readValues = Lists.newArrayList(values); + + ObjectInspector objectInspector = getStandardStructObjectInspector(structFieldNames.get(), singletonList(javaIntObjectInspector)); + Type type = RowType.from(singletonList(field("structField", INTEGER))); + + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add("root"); + + for (int i = 0; i < nestingLevel; i++) { + int n = ThreadLocalRandom.current().nextInt(2, 5); + values = insertNullEvery(n, createTestStructs(values)); + readValues = insertNullEvery(n, readValues); + objectInspector = getStandardStructObjectInspector(structFieldNames.get(), singletonList(objectInspector)); + builder.add("structField"); + } + tester.testNestedColumnRoundTrip(singletonList(objectInspector), new Iterable[] {values}, + new Iterable[] {readValues}, ImmutableList.of(new NestedColumn(builder.build())), singletonList(type), Optional.empty(), false); + } + @Test public void testComplexNestedStructs() throws Exception diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java index f52cac556077..885093655720 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/ParquetTester.java @@ -31,6 +31,7 @@ import io.prestosql.plugin.hive.parquet.write.SingleLevelArrayMapKeyValuesSchemaConverter; import 
io.prestosql.plugin.hive.parquet.write.SingleLevelArraySchemaConverter; import io.prestosql.plugin.hive.parquet.write.TestMapredParquetOutputFormat; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ConnectorPageSource; @@ -79,6 +80,7 @@ import java.util.Properties; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import static com.google.common.base.Functions.constant; import static com.google.common.base.Verify.verify; @@ -190,6 +192,19 @@ public void testRoundTrip(ObjectInspector objectInspector, Iterable writeValu } } + public void testNestedColumnRoundTrip(List objectInspectors, Iterable[] writeValues, Iterable[] readValues, List nestedColumns, List columnTypes, Optional parquetSchema, boolean singleLevelArray) + throws Exception + { + List columnNames = nestedColumns.stream().map(NestedColumn::getName).collect(Collectors.toList()); + List rootNames = nestedColumns.stream().map(NestedColumn::getBase).collect(Collectors.toList()); + + // just the values + testRoundTripType(objectInspectors, writeValues, readValues, columnNames, columnTypes, Optional.of(rootNames), parquetSchema, singleLevelArray); + + // all nulls + assertRoundTrip(objectInspectors, transformToNulls(writeValues), transformToNulls(readValues), columnNames, columnTypes, Optional.of(rootNames), parquetSchema, singleLevelArray); + } + public void testRoundTrip(ObjectInspector objectInspector, Iterable writeValues, Iterable readValues, Type type, Optional parquetSchema) throws Exception { @@ -247,18 +262,32 @@ private void testRoundTripType( Optional parquetSchema, boolean singleLevelArray) throws Exception + { + testRoundTripType(objectInspectors, writeValues, readValues, columnNames, columnTypes, Optional.empty(), parquetSchema, singleLevelArray); + } + + private void testRoundTripType( + List objectInspectors, + Iterable[] writeValues, + Iterable[] readValues, + List columnNames, + List columnTypes, + Optional> rootColumns, + Optional parquetSchema, + boolean singleLevelArray) + throws Exception { // forward order - assertRoundTrip(objectInspectors, writeValues, readValues, columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, writeValues, readValues, columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); // reverse order - assertRoundTrip(objectInspectors, reverse(writeValues), reverse(readValues), columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, reverse(writeValues), reverse(readValues), columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); // forward order with nulls - assertRoundTrip(objectInspectors, insertNullEvery(5, writeValues), insertNullEvery(5, readValues), columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, insertNullEvery(5, writeValues), insertNullEvery(5, readValues), columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); // reverse order with nulls - assertRoundTrip(objectInspectors, insertNullEvery(5, reverse(writeValues)), insertNullEvery(5, reverse(readValues)), columnNames, columnTypes, parquetSchema, singleLevelArray); + assertRoundTrip(objectInspectors, insertNullEvery(5, reverse(writeValues)), insertNullEvery(5, reverse(readValues)), columnNames, columnTypes, rootColumns, parquetSchema, singleLevelArray); } void assertRoundTrip( @@ -282,6 +311,19 @@ void assertRoundTrip( Optional 
parquetSchema, boolean singleLevelArray) throws Exception + { + assertRoundTrip(objectInspectors, writeValues, readValues, columnNames, columnTypes, Optional.empty(), parquetSchema, singleLevelArray); + } + + void assertRoundTrip(List objectInspectors, + Iterable[] writeValues, + Iterable[] readValues, + List columnNames, + List columnTypes, + Optional> rootColumns, + Optional parquetSchema, + boolean singleLevelArray) + throws Exception { for (WriterVersion version : versions) { for (CompressionCodecName compressionCodecName : compressions) { @@ -291,12 +333,13 @@ void assertRoundTrip( jobConf.setEnum(COMPRESSION, compressionCodecName); jobConf.setBoolean(ENABLE_DICTIONARY, true); jobConf.setEnum(WRITER_VERSION, version); + List writeColumns = rootColumns.isPresent() ? rootColumns.get() : columnNames; writeParquetColumn( jobConf, tempFile.getFile(), compressionCodecName, - createTableProperties(columnNames, objectInspectors), - getStandardStructObjectInspector(columnNames, objectInspectors), + createTableProperties(writeColumns, objectInspectors), + getStandardStructObjectInspector(writeColumns, objectInspectors), getIterators(writeValues), parquetSchema, singleLevelArray); @@ -605,7 +648,7 @@ private Iterable[] transformToNulls(Iterable[] values) private static Iterable[] reverse(Iterable[] iterables) { return stream(iterables) - .map(ImmutableList::copyOf) + .map(Lists::newArrayList) .map(Lists::reverse) .toArray(Iterable[]::new); } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java index be41c742a719..07661293462f 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -56,7 +56,7 @@ public class TestParquetPredicateUtils @Test public void testParquetTupleDomainPrimitiveArray() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_array", HiveType.valueOf("array"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_array", HiveType.valueOf("array"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty(), Optional.empty()); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(INTEGER)))); MessageType fileSchema = new MessageType("hive_schema", @@ -71,7 +71,7 @@ public void testParquetTupleDomainPrimitiveArray() @Test public void testParquetTupleDomainStructArray() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_array_struct", HiveType.valueOf("array>"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_array_struct", HiveType.valueOf("array>"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty(), Optional.empty()); RowType.Field rowField = new RowType.Field(Optional.of("a"), INTEGER); RowType rowType = RowType.from(ImmutableList.of(rowField)); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(rowType)))); @@ -89,7 +89,7 @@ public void testParquetTupleDomainStructArray() @Test public void testParquetTupleDomainPrimitive() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_primitive", HiveType.valueOf("bigint"), parseTypeSignature(StandardTypes.BIGINT), 
0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_primitive", HiveType.valueOf("bigint"), parseTypeSignature(StandardTypes.BIGINT), 0, REGULAR, Optional.empty(), Optional.empty()); Domain singleValueDomain = Domain.singleValue(BIGINT, 123L); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, singleValueDomain)); @@ -110,7 +110,7 @@ public void testParquetTupleDomainPrimitive() @Test public void testParquetTupleDomainStruct() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_struct", HiveType.valueOf("struct"), parseTypeSignature(StandardTypes.ROW), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_struct", HiveType.valueOf("struct"), parseTypeSignature(StandardTypes.ROW), 0, REGULAR, Optional.empty(), Optional.empty()); RowType.Field rowField = new RowType.Field(Optional.of("my_struct"), INTEGER); RowType rowType = RowType.from(ImmutableList.of(rowField)); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(rowType))); @@ -127,7 +127,7 @@ public void testParquetTupleDomainStruct() @Test public void testParquetTupleDomainMap() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_map", HiveType.valueOf("map"), parseTypeSignature(StandardTypes.MAP), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_map", HiveType.valueOf("map"), parseTypeSignature(StandardTypes.MAP), 0, REGULAR, Optional.empty(), Optional.empty()); MapType mapType = new MapType( INTEGER, diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java index 812ca75ad20d..649e75460b8d 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java @@ -94,8 +94,8 @@ public class TestMetastoreHiveStatisticsProvider private static final String COLUMN = "column"; private static final DecimalType DECIMAL = createDecimalType(5, 3); - private static final HiveColumnHandle PARTITION_COLUMN_1 = new HiveColumnHandle("p1", HIVE_STRING, VARCHAR.getTypeSignature(), 0, PARTITION_KEY, Optional.empty()); - private static final HiveColumnHandle PARTITION_COLUMN_2 = new HiveColumnHandle("p2", HIVE_LONG, BIGINT.getTypeSignature(), 1, PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle PARTITION_COLUMN_1 = new HiveColumnHandle("p1", HIVE_STRING, VARCHAR.getTypeSignature(), 0, PARTITION_KEY, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle PARTITION_COLUMN_2 = new HiveColumnHandle("p2", HIVE_LONG, BIGINT.getTypeSignature(), 1, PARTITION_KEY, Optional.empty(), Optional.empty()); @Test public void testGetPartitionsSample() @@ -611,7 +611,7 @@ public void testGetTableStatistics() .build(); MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((table, hivePartitions) -> ImmutableMap.of(partitionName, statistics)); TestingConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(new HiveConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties()); - HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new 
HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty(), Optional.empty());
         TableStatistics expected = TableStatistics.builder()
                 .setRowCount(Estimate.of(1000))
                 .setColumnStatistics(
@@ -661,7 +661,7 @@ public void testGetTableStatisticsUnpartitioned()
                 .build();
         MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((table, hivePartitions) -> ImmutableMap.of(UNPARTITIONED_ID, statistics));
         TestingConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(new HiveConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties());
-        HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty());
+        HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty(), Optional.empty());
         TableStatistics expected = TableStatistics.builder()
                 .setRowCount(Estimate.of(1000))
                 .setColumnStatistics(
diff --git a/presto-main/etc/catalog/hive.properties b/presto-main/etc/catalog/hive.properties
index 9060aff2a6a7..462c16ee537f 100644
--- a/presto-main/etc/catalog/hive.properties
+++ b/presto-main/etc/catalog/hive.properties
@@ -7,3 +7,5 @@
 connector.name=hive-hadoop2
 hive.metastore.uri=thrift://localhost:9083
+
+hive.parquet.use-column-names=true
diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java
index f4d61714728d..a33cb04f340b 100644
--- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java
+++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java
@@ -16,6 +16,7 @@
 import io.airlift.slice.Slice;
 import io.prestosql.Session;
 import io.prestosql.connector.CatalogName;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.PrestoException;
 import io.prestosql.spi.block.BlockEncodingSerde;
 import io.prestosql.spi.connector.CatalogSchemaName;
@@ -121,6 +122,13 @@ public interface Metadata
      */
     Map<String, ColumnHandle> getColumnHandles(Session session, TableHandle tableHandle);

+    /**
+     * Resolves the given nested column references (dereference chains) against the specified
+     * table, returning a column handle for each reference the connector can read directly;
+     * references the connector cannot push down are absent from the result.
+     *
+     * @throws RuntimeException if table handle is no longer valid
+     */
+    Map<NestedColumn, ColumnHandle> getNestedColumnHandles(Session session, TableHandle tableHandle, Collection<NestedColumn> dereferences);
+
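Review note: a hedged sketch of the intended call pattern from the engine side; the plumbing lives in the new MergeNestedColumn optimizer further down, and the variable names here are illustrative only.

```java
// Illustrative only: collect the dereference chains that appear above a table
// scan, ask the connector which of them it can serve as real columns, and keep
// only the resolved ones. Unresolved chains keep their original plan shape.
Collection<NestedColumn> dereferences = ImmutableList.of(
        new NestedColumn(ImmutableList.of("msg", "user", "id")));
Map<NestedColumn, ColumnHandle> resolved =
        metadata.getNestedColumnHandles(session, tableHandle, dereferences);
// resolved.isEmpty() => the connector opted out (e.g. a non-Parquet Hive table),
// so the plan is left unchanged.
```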
* diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index a14cd8509436..beb2a799bdbd 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -27,6 +27,7 @@ import io.prestosql.Session; import io.prestosql.block.BlockEncodingManager; import io.prestosql.connector.CatalogName; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.PrestoException; import io.prestosql.spi.QueryId; import io.prestosql.spi.block.BlockEncodingSerde; @@ -522,6 +523,14 @@ public Map getColumnHandles(Session session, TableHandle t return map.build(); } + @Override + public Map getNestedColumnHandles(Session session, TableHandle tableHandle, Collection dereferences) + { + CatalogName catalogName = tableHandle.getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + return metadata.getNestedColumnHandles(session.toConnectorSession(catalogName), tableHandle.getConnectorHandle(), dereferences); + } + @Override public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 5d19e08e41bc..6de8269db8ae 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -118,12 +118,14 @@ import io.prestosql.sql.planner.optimizations.ImplementIntersectAndExceptAsUnion; import io.prestosql.sql.planner.optimizations.IndexJoinOptimizer; import io.prestosql.sql.planner.optimizations.LimitPushDown; +import io.prestosql.sql.planner.optimizations.MergeNestedColumn; import io.prestosql.sql.planner.optimizations.MetadataDeleteOptimizer; import io.prestosql.sql.planner.optimizations.MetadataQueryOptimizer; import io.prestosql.sql.planner.optimizations.OptimizeMixedDistinctAggregations; import io.prestosql.sql.planner.optimizations.PlanOptimizer; import io.prestosql.sql.planner.optimizations.PredicatePushDown; import io.prestosql.sql.planner.optimizations.PruneUnreferencedOutputs; +import io.prestosql.sql.planner.optimizations.PushDownDereferenceExpression; import io.prestosql.sql.planner.optimizations.ReplicateSemiJoinInDelete; import io.prestosql.sql.planner.optimizations.SetFlatteningOptimizer; import io.prestosql.sql.planner.optimizations.StatsRecordingPlanOptimizer; @@ -359,6 +361,13 @@ public PlanOptimizers( new TransformCorrelatedSingleRowSubqueryToProject(), new RemoveAggregationInSemiJoin())), new CheckSubqueryNodesAreRewritten(), + + // pushdown dereference + new PushDownDereferenceExpression(metadata, typeAnalyzer), + new PruneUnreferencedOutputs(), + new MergeNestedColumn(metadata, typeAnalyzer), + new IterativeOptimizer(ruleStats, statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PruneTableScanColumns())), + predicatePushDown, new IterativeOptimizer( ruleStats, @@ -408,6 +417,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new EliminateCrossJoins())), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again + predicatePushDown, simplifyOptimizer, // Should be always run after PredicatePushDown new IterativeOptimizer( diff --git 
a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index 5ae6bad95514..c5f85fc30380 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -156,6 +156,7 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever + //.filter(entry -> !(child.getAssignments().get(entry.getKey()) instanceof DereferenceExpression)) // skip dereference expression .map(Map.Entry::getKey) .collect(toSet()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java new file mode 100644 index 000000000000..3510c04db1a4 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/MergeNestedColumn.java @@ -0,0 +1,317 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.optimizations; + +import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.metadata.TableHandle; +import io.prestosql.spi.NestedColumn; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SimplePlanRewriter; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.ExpressionRewriter; +import io.prestosql.sql.tree.ExpressionTreeRewriter; +import io.prestosql.sql.tree.Identifier; +import io.prestosql.sql.tree.SubscriptExpression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class MergeNestedColumn + implements PlanOptimizer +{ + private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; + + public MergeNestedColumn(Metadata metadata, TypeAnalyzer typeAnalyzer) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + return SimplePlanRewriter.rewriteWith(new Optimizer(session, symbolAllocator, idAllocator, metadata, typeAnalyzer, warningCollector), plan); + } + + public static boolean prefixExist(Expression expression, final Set allDereferences) + { + int[] referenceCount = {0}; + new DefaultExpressionTraversalVisitor() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, int[] referenceCount) + { + if (allDereferences.contains(node)) { + referenceCount[0] += 1; + } + process(node.getBase(), referenceCount); + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, int[] context) + { + if (allDereferences.contains(node)) { + referenceCount[0] += 1; + } + return null; + } + }.process(expression, referenceCount); + return referenceCount[0] > 1; + } + + private static class Optimizer + extends SimplePlanRewriter + { + private final Session session; + private final SymbolAllocator symbolAllocator; + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; + + public Optimizer(Session session, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata 
metadata, TypeAnalyzer typeAnalyzer, WarningCollector warningCollector) + { + this.session = requireNonNull(session); + this.symbolAllocator = requireNonNull(symbolAllocator); + this.idAllocator = requireNonNull(idAllocator); + this.metadata = requireNonNull(metadata); + this.typeAnalyzer = requireNonNull(typeAnalyzer); + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext context) + { + if (node.getSource() instanceof TableScanNode) { + TableScanNode tableScanNode = (TableScanNode) node.getSource(); + return mergeProjectWithTableScan(node, tableScanNode, context); + } + return context.defaultRewrite(node); + } + + private Type extractType(Expression expression) + { + Type type = typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression); + return type; + } + + public PlanNode mergeProjectWithTableScan(ProjectNode node, TableScanNode tableScanNode, RewriteContext context) + { + Set allExpressions = node.getAssignments().getExpressions().stream().map(MergeNestedColumn::validDereferenceExpression).filter(Objects::nonNull).collect(toImmutableSet()); + Set dereferences = allExpressions.stream() + .filter(expression -> !prefixExist(expression, allExpressions)) + .filter(expression -> expression instanceof DereferenceExpression) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return context.defaultRewrite(node); + } + + NestedColumnTranslator nestedColumnTranslator = new NestedColumnTranslator(tableScanNode.getAssignments(), tableScanNode.getTable()); + Map nestedColumns = dereferences.stream().collect(Collectors.toMap(Function.identity(), nestedColumnTranslator::toNestedColumn)); + + Map nestedColumnHandles = + metadata.getNestedColumnHandles(session, tableScanNode.getTable(), nestedColumns.values()) + .entrySet().stream() + .filter(entry -> !nestedColumnTranslator.columnHandleExists(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (nestedColumnHandles.isEmpty()) { + return context.defaultRewrite(node); + } + + ImmutableMap.Builder columnHandleBuilder = ImmutableMap.builder(); + columnHandleBuilder.putAll(tableScanNode.getAssignments()); + + // Use to replace expression in original dereference expression + ImmutableMap.Builder symbolExpressionBuilder = ImmutableMap.builder(); + for (Map.Entry entry : nestedColumnHandles.entrySet()) { + NestedColumn nestedColumn = entry.getKey(); + Expression expression = nestedColumnTranslator.toExpression(nestedColumn); + Symbol symbol = symbolAllocator.newSymbol(nestedColumn.getName(), extractType(expression)); + symbolExpressionBuilder.put(expression, symbol); + columnHandleBuilder.put(symbol, entry.getValue()); + } + ImmutableMap nestedColumnsMap = columnHandleBuilder.build(); + + TableScanNode newTableScan = new TableScanNode( + idAllocator.getNextId(), + tableScanNode.getTable(), + ImmutableList.copyOf(nestedColumnsMap.keySet()), + nestedColumnsMap, + tableScanNode.getEnforcedConstraint()); + + Rewriter rewriter = new Rewriter(symbolExpressionBuilder.build()); + Map assignments = node.getAssignments().entrySet().stream().collect( + Collectors.toMap(Map.Entry::getKey, entry -> ExpressionTreeRewriter.rewriteWith(rewriter, entry.getValue()))); + return new ProjectNode(idAllocator.getNextId(), newTableScan, Assignments.copyOf(assignments)); + } + + private class NestedColumnTranslator + { + private final Map symbolToColumnName; + private final Map columnNameToSymbol; + + NestedColumnTranslator(Map columnHandleMap, TableHandle tableHandle) + { + BiMap 
symbolToColumnName = HashBiMap.create(columnHandleMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> metadata.getColumnMetadata(session, tableHandle, entry.getValue()).getName()))); + this.symbolToColumnName = symbolToColumnName; + this.columnNameToSymbol = symbolToColumnName.inverse(); + } + + boolean columnHandleExists(NestedColumn nestedColumn) + { + return columnNameToSymbol.containsKey(nestedColumn.getName()); + } + + NestedColumn toNestedColumn(Expression expression) + { + ImmutableList.Builder<String> builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor<Void, Void>() + { + @Override + protected Void visitSubscriptExpression(SubscriptExpression node, Void context) + { + return null; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + process(node.getBase(), context); + builder.add(node.getField().getValue()); + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, Void context) + { + Symbol baseName = Symbol.from(node); + Preconditions.checkArgument(symbolToColumnName.containsKey(baseName), "base [%s] doesn't exist in assignments [%s]", baseName, symbolToColumnName); + builder.add(symbolToColumnName.get(baseName)); + return null; + } + }.process(expression, null); + List<String> names = builder.build(); + Preconditions.checkArgument(names.size() > 1, "names [%s] must have at least two parts", names); + return new NestedColumn(names); + } + + Expression toExpression(NestedColumn nestedColumn) + { + Expression result = null; + for (String part : nestedColumn.getNames()) { + if (result == null) { + Preconditions.checkArgument(columnNameToSymbol.containsKey(part), "element %s doesn't exist in map %s", part, columnNameToSymbol); + result = columnNameToSymbol.get(part).toSymbolReference(); + } + else { + result = new DereferenceExpression(result, new Identifier(part)); + } + } + return result; + } + } + } + + // Replaces dereference expressions whose chains were pushed down (e.g. msg_12.foo, translated to NestedColumn msg.foo and back) with the newly assigned scan symbols + + private static class Rewriter + extends ExpressionRewriter<Void> + { + private final Map<Expression, Symbol> map; + + Rewriter(Map<Expression, Symbol> map) + { + this.map = map; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) + { + if (map.containsKey(node)) { + return map.get(node).toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + + @Override + public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) + { + if (map.containsKey(node)) { + return map.get(node).toSymbolReference(); + } + return super.rewriteSymbolReference(node, context, treeRewriter); + } + } + + public static Expression validDereferenceExpression(Expression expression) + { + //Preconditions.checkArgument(expression instanceof DereferenceExpression, "expression must be a dereference expression"); + SubscriptExpression[] shortestSubscriptExp = new SubscriptExpression[1]; + boolean[] valid = new boolean[1]; + valid[0] = true; + new DefaultExpressionTraversalVisitor<Void, Void>() + { + @Override + protected Void visitSubscriptExpression(SubscriptExpression node, Void context) + { + shortestSubscriptExp[0] = node; + process(node.getBase(), context); + return null; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + valid[0] &= (node.getBase() instanceof SymbolReference || node.getBase() instanceof DereferenceExpression || node.getBase() instanceof 
SubscriptExpression); + process(node.getBase(), context); + return null; + } + }.process(expression, null); + if (valid[0]) { + return shortestSubscriptExp[0] == null ? expression : shortestSubscriptExp[0].getBase(); + } + else { + return null; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java new file mode 100644 index 000000000000..e57165f65679 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PushDownDereferenceExpression.java @@ -0,0 +1,367 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.optimizations; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.prestosql.Session; +import io.prestosql.execution.warnings.WarningCollector; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.TypeProvider; +import io.prestosql.sql.planner.plan.AggregationNode; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SimplePlanRewriter; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.UnnestNode; +import io.prestosql.sql.planner.plan.ValuesNode; +import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.ExpressionRewriter; +import io.prestosql.sql.tree.ExpressionTreeRewriter; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.sql.planner.optimizations.MergeNestedColumn.prefixExist; +import static java.util.Objects.requireNonNull; + +public class PushDownDereferenceExpression + implements PlanOptimizer +{ + private final Metadata metadata; + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferenceExpression(Metadata metadata, TypeAnalyzer typeAnalyzer) + { + this.metadata = requireNonNull(metadata, "metadata is 
null"); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + Map expressionInfoMap = new HashMap<>(); + return SimplePlanRewriter.rewriteWith(new Optimizer(session, metadata, typeAnalyzer, symbolAllocator, idAllocator, warningCollector), plan, expressionInfoMap); + } + + private static class DereferenceReplacer + extends ExpressionRewriter + { + private final Map map; + + DereferenceReplacer(Map map) + { + this.map = map; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (map.containsKey(node) && map.get(node).isFromValidSource()) { + return map.get(node).getSymbol().toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + } + + private static class Optimizer + extends SimplePlanRewriter> + { + private final Session session; + private final SymbolAllocator symbolAllocator; + private final TypeAnalyzer typeAnalyzer; + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final WarningCollector warningCollector; + + private Optimizer(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + this.session = session; + this.typeAnalyzer = typeAnalyzer; + this.metadata = metadata; + this.symbolAllocator = symbolAllocator; + this.idAllocator = idAllocator; + this.warningCollector = warningCollector; + } + + @Override + public PlanNode visitAggregation(AggregationNode node, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + extractDereferenceInfos(node).forEach(expressionInfoMap::putIfAbsent); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + Map aggregations = new HashMap<>(); + for (Map.Entry symbolAggregationEntry : node.getAggregations().entrySet()) { + Symbol symbol = symbolAggregationEntry.getKey(); + AggregationNode.Aggregation oldAggregation = symbolAggregationEntry.getValue(); + AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation(ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), oldAggregation.getCall()), oldAggregation.getSignature(), oldAggregation.getMask()); + aggregations.put(symbol, newAggregation); + } + return new AggregationNode( + idAllocator.getNextId(), + child, + aggregations, + node.getGroupingSets(), + node.getPreGroupedSymbols(), + node.getStep(), + node.getHashSymbol(), + node.getGroupIdSymbol()); + } + + @Override + public PlanNode visitFilter(FilterNode node, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + extractDereferenceInfos(node).forEach(expressionInfoMap::putIfAbsent); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + Expression predicate = ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), node.getPredicate()); + return new FilterNode(idAllocator.getNextId(), child, predicate); + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + // parentDereferenceInfos is used to find out passThroughSymbol. 
we will only pass those symbols that are needed by upstream + List parentDereferenceInfos = expressionInfoMap.entrySet().stream().map(Map.Entry::getValue).collect(Collectors.toList()); + Map newDereferences = extractDereferenceInfos(node); + newDereferences.forEach(expressionInfoMap::putIfAbsent); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + List passThroughSymbols = getUsedDereferenceInfo(node.getOutputSymbols(), parentDereferenceInfos).stream().filter(DereferenceInfo::isFromValidSource).map(DereferenceInfo::getSymbol).collect(Collectors.toList()); + + Assignments.Builder assignmentsBuilder = Assignments.builder(); + for (Map.Entry entry : node.getAssignments().entrySet()) { + assignmentsBuilder.put(entry.getKey(), ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), entry.getValue())); + } + assignmentsBuilder.putIdentities(passThroughSymbols); + ProjectNode newProjectNode = new ProjectNode(idAllocator.getNextId(), child, assignmentsBuilder.build()); + newDereferences.forEach(expressionInfoMap::remove); + return newProjectNode; + } + + @Override + public PlanNode visitTableScan(TableScanNode node, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + List usedDereferenceInfo = getUsedDereferenceInfo(node.getOutputSymbols(), expressionInfoMap.values()); + if (!usedDereferenceInfo.isEmpty()) { + usedDereferenceInfo.forEach(DereferenceInfo::doesFromValidSource); + Map assignmentMap = usedDereferenceInfo.stream().collect(Collectors.toMap(DereferenceInfo::getSymbol, DereferenceInfo::getDereference)); + return new ProjectNode(idAllocator.getNextId(), node, Assignments.builder().putAll(assignmentMap).putIdentities(node.getOutputSymbols()).build()); + } + return node; + } + + @Override + public PlanNode visitValues(ValuesNode node, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + List usedDereferenceInfo = getUsedDereferenceInfo(node.getOutputSymbols(), expressionInfoMap.values()); + if (!usedDereferenceInfo.isEmpty()) { + usedDereferenceInfo.forEach(DereferenceInfo::doesFromValidSource); + Map assignmentMap = usedDereferenceInfo.stream().collect(Collectors.toMap(DereferenceInfo::getSymbol, DereferenceInfo::getDereference)); + return new ProjectNode(idAllocator.getNextId(), node, Assignments.builder().putAll(assignmentMap).putIdentities(node.getOutputSymbols()).build()); + } + return node; + } + + @Override + public PlanNode visitJoin(JoinNode joinNode, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + extractDereferenceInfos(joinNode).forEach(expressionInfoMap::putIfAbsent); + + PlanNode leftNode = context.rewrite(joinNode.getLeft(), expressionInfoMap); + PlanNode rightNode = context.rewrite(joinNode.getRight(), expressionInfoMap); + + List equiJoinClauses = joinNode.getCriteria().stream() + .map(JoinNode.EquiJoinClause::toExpression) + .map(expr -> ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), expr)) + .map(this::getEquiJoinClause) + .collect(Collectors.toList()); + + Optional joinFilter = joinNode.getFilter().map(expression -> ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressionInfoMap), expression)); + + return new JoinNode( + joinNode.getId(), + joinNode.getType(), + leftNode, + rightNode, + equiJoinClauses, + ImmutableList.builder().addAll(leftNode.getOutputSymbols()).addAll(rightNode.getOutputSymbols()).build(), + joinFilter, + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + 
joinNode.getDistributionType(), + joinNode.isSpillable()); + } + + @Override + public PlanNode visitUnnest(UnnestNode node, RewriteContext> context) + { + Map expressionInfoMap = context.get(); + List parentDereferenceInfos = expressionInfoMap.entrySet().stream().map(Map.Entry::getValue).collect(Collectors.toList()); + + PlanNode child = context.rewrite(node.getSource(), expressionInfoMap); + + List passThroughSymbols = getUsedDereferenceInfo(child.getOutputSymbols(), parentDereferenceInfos).stream().filter(DereferenceInfo::isFromValidSource).map(DereferenceInfo::getSymbol).collect(Collectors.toList()); + UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), child, ImmutableList.builder().addAll(node.getReplicateSymbols()).addAll(passThroughSymbols).build(), node.getUnnestSymbols(), node.getOrdinalitySymbol()); + + List unnestSymbols = unnestNode.getUnnestSymbols().entrySet().stream().flatMap(entry -> entry.getValue().stream()).collect(Collectors.toList()); + List dereferenceExpressionInfos = getUsedDereferenceInfo(unnestSymbols, expressionInfoMap.values()); + if (!dereferenceExpressionInfos.isEmpty()) { + dereferenceExpressionInfos.forEach(DereferenceInfo::doesFromValidSource); + Map assignmentMap = dereferenceExpressionInfos.stream().collect(Collectors.toMap(DereferenceInfo::getSymbol, DereferenceInfo::getDereference)); + return new ProjectNode(idAllocator.getNextId(), unnestNode, Assignments.builder().putAll(assignmentMap).putIdentities(unnestNode.getOutputSymbols()).build()); + } + return unnestNode; + } + + private List getUsedDereferenceInfo(List symbols, Collection dereferenceExpressionInfos) + { + Set symbolSet = symbols.stream().collect(Collectors.toSet()); + return dereferenceExpressionInfos.stream().filter(dereferenceExpressionInfo -> symbolSet.contains(dereferenceExpressionInfo.getBaseSymbol())).collect(Collectors.toList()); + } + + private JoinNode.EquiJoinClause getEquiJoinClause(Expression expression) + { + checkArgument(expression instanceof ComparisonExpression, "expression [%s] is not equal expression", expression); + ComparisonExpression comparisonExpression = (ComparisonExpression) expression; + return new JoinNode.EquiJoinClause(Symbol.from(comparisonExpression.getLeft()), Symbol.from(comparisonExpression.getRight())); + } + + private Type extractType(Expression expression) + { + Type type = typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression); + return type; + } + + private DereferenceInfo getDereferenceInfo(Expression expression) + { + Symbol symbol = symbolAllocator.newSymbol(expression, extractType(expression)); + Symbol base = Iterables.getOnlyElement(SymbolsExtractor.extractAll(expression)); + return new DereferenceInfo(expression, symbol, base); + } + + private List extractDereference(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder context) + { + context.add(node); + return null; + } + }.process(expression, builder); + return builder.build(); + } + + private Map extractDereferenceInfos(PlanNode node) + { + Set allExpressions = ExpressionExtractor.extractExpressionsNonRecursive(node).stream() + .flatMap(expression -> extractDereference(expression).stream()) + .map(MergeNestedColumn::validDereferenceExpression).filter(Objects::nonNull).collect(toImmutableSet()); + + return allExpressions.stream() + .filter(expression -> !prefixExist(expression, 
allExpressions)) + .filter(expression -> expression instanceof DereferenceExpression) + .distinct() + .map(this::getDereferenceInfo) + .collect(Collectors.toMap(DereferenceInfo::getDereference, Function.identity())); + } + } + + private static class DereferenceInfo + { + // e.g. for the dereference expression msg.foo[1].bar, baseSymbol is "msg" and symbol is the newly assigned symbol that replaces the whole expression + private final Expression dereferenceExpression; + private final Symbol symbol; + private final Symbol baseSymbol; + + // fromValidSource records whether the dereference expression is rooted at a TableScan or Unnest + // it stays false for plans like the following, so we will not rewrite: + // Project[expr_1 := "max_by"."field1"] + // - Aggregate[max_by := "max_by"("expr", "app_rating")] => [max_by:row(field0 varchar, field1 varchar)] + private boolean fromValidSource; + + public DereferenceInfo(Expression dereferenceExpression, Symbol symbol, Symbol baseSymbol) + { + this.dereferenceExpression = requireNonNull(dereferenceExpression); + this.symbol = requireNonNull(symbol); + this.baseSymbol = requireNonNull(baseSymbol); + this.fromValidSource = false; + } + + public Symbol getSymbol() + { + return symbol; + } + + public Symbol getBaseSymbol() + { + return baseSymbol; + } + + public Expression getDereference() + { + return dereferenceExpression; + } + + public boolean isFromValidSource() + { + return fromValidSource; + } + + public void doesFromValidSource() + { + fromValidSource = true; + } + + @Override + public String toString() + { + return String.format("(%s, %s, %s)", dereferenceExpression, symbol, baseSymbol); + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java b/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java index c9e4e4257818..baae35316aa0 100644 --- a/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java +++ b/presto-main/src/main/java/io/prestosql/testing/TestingMetadata.java @@ -55,6 +55,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.spi.StandardErrorCode.ALREADY_EXISTS; import static java.util.Objects.requireNonNull; @@ -65,6 +66,16 @@ public class TestingMetadata private final ConcurrentMap tables = new ConcurrentHashMap<>(); private final ConcurrentMap views = new ConcurrentHashMap<>(); + protected ConcurrentMap getTables() + { + return tables; + } + + protected ConcurrentMap getViews() + { + return views; + } + @Override public List listSchemaNames(ConnectorSession session) { @@ -291,7 +302,7 @@ public void clear() tables.clear(); } - private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) + protected static SchemaTableName getTableName(ConnectorTableHandle tableHandle) { requireNonNull(tableHandle, "tableHandle is null"); checkArgument(tableHandle instanceof TestingTableHandle, "tableHandle is not an instance of TestingTableHandle"); @@ -398,5 +409,11 @@ public int hashCode() { return Objects.hash(name, ordinalPosition, type); } + + @Override + public String toString() + { + return toStringHelper(this).add("name", name).add("position", ordinalPosition).add("type", type).toString(); + } } } diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index 
f6defbde6855..19a9aa82a5f8 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -16,6 +16,7 @@ import io.airlift.slice.Slice; import io.prestosql.Session; import io.prestosql.connector.CatalogName; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.block.BlockEncodingSerde; import io.prestosql.spi.connector.CatalogSchemaName; import io.prestosql.spi.connector.ColumnHandle; @@ -170,6 +171,12 @@ public Map getColumnHandles(Session session, TableHandle t throw new UnsupportedOperationException(); } + @Override + public Map getNestedColumnHandles(Session session, TableHandle tableHandle, Collection dereferences) + { + throw new UnsupportedOperationException(); + } + @Override public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java new file mode 100644 index 000000000000..0cb916e114a5 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import org.testng.annotations.Test; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; + +public class TestDereferencePushDown + extends BasePlanTest +{ + private static final String VALUES = "(values ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))"; + + @Test + public void testPushDownDereferencesThroughJoin() + { + assertPlan(" with t1 as ( select * from " + VALUES + " as t (msg) ) select b.msg.x from t1 a, t1 b where a.msg.y = b.msg.y", + output(ImmutableList.of("x"), + join(INNER, ImmutableList.of(equiJoinClause("left_y", "right_y")), + anyTree( + project(ImmutableMap.of("left_y", expression("field.y")), + values("field")) + ), anyTree( + project(ImmutableMap.of("right_y", expression("field1.y"), "x", expression("field1.x")), + values("field1")))))); + } + + @Test + public void testPushDownDereferencesInCase() + { + // Dereferences in a THEN clause must not be eagerly evaluated. + String statement = "with t as (select * from (values cast(array[CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)),ROW(1, 2.0)] as array<row(x BIGINT, y DOUBLE)>)) as t (arr) ) " + + "select case when cardinality(arr) > cast(0 as bigint) then arr[cast(1 as bigint)].x end from t"; + assertPlan(statement, + output(ImmutableList.of("x"), + project(ImmutableMap.of("x", expression("case when cardinality(field) > bigint '0' then field[bigint '1'].x end")), values("field")))); + } + + @Test + public void testPushDownDereferencesThroughFilter() + { + assertPlan(" with t1 as ( select * from " + VALUES + " as t (msg) ) select a.msg.y from t1 a join t1 b on a.msg.y = b.msg.y where a.msg.x > bigint '5'", + output(ImmutableList.of("left_y"), + join(INNER, ImmutableList.of(equiJoinClause("left_y", "right_y")), + anyTree( + project(ImmutableMap.of("left_y", expression("field.y")), + filter("field.x > bigint '5'", values("field"))) + ), anyTree( + project(ImmutableMap.of("right_y", expression("field1.y")), + values("field1")))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java new file mode 100644 index 000000000000..9a76ca0e5af0 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestMergeNestedColumns.java @@ -0,0 +1,255 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.NestedColumn; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.ColumnMetadata; +import io.prestosql.spi.connector.Connector; +import io.prestosql.spi.connector.ConnectorContext; +import io.prestosql.spi.connector.ConnectorFactory; +import io.prestosql.spi.connector.ConnectorHandleResolver; +import io.prestosql.spi.connector.ConnectorMetadata; +import io.prestosql.spi.connector.ConnectorPageSinkProvider; +import io.prestosql.spi.connector.ConnectorPageSource; +import io.prestosql.spi.connector.ConnectorPageSourceProvider; +import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.connector.ConnectorSplit; +import io.prestosql.spi.connector.ConnectorSplitManager; +import io.prestosql.spi.connector.ConnectorTableHandle; +import io.prestosql.spi.connector.ConnectorTableMetadata; +import io.prestosql.spi.connector.ConnectorTransactionHandle; +import io.prestosql.spi.connector.FixedPageSource; +import io.prestosql.spi.connector.SchemaTableName; +import io.prestosql.spi.transaction.IsolationLevel; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VarcharType; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import io.prestosql.testing.LocalQueryRunner; +import io.prestosql.testing.TestingHandleResolver; +import io.prestosql.testing.TestingMetadata; +import io.prestosql.testing.TestingPageSinkProvider; +import io.prestosql.testing.TestingSplitManager; +import org.testng.annotations.Test; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static io.prestosql.spi.type.RowType.field; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.prestosql.testing.TestingSession.testSessionBuilder; +import static java.util.Objects.requireNonNull; + +public class TestMergeNestedColumns + extends BasePlanTest +{ + private static final Type MSG_TYPE = RowType.from(ImmutableList.of(field("x", VarcharType.VARCHAR), field("y", VarcharType.VARCHAR))); + + public TestMergeNestedColumns() + { + super(TestMergeNestedColumns::createQueryRunner); + } + + @Test + public void testSelectDereference() + { + assertPlan("select foo.x, foo.y, bar.x, bar.y from nested_column_table", + output(ImmutableList.of("foo_x", "foo_y", "bar_x", "bar_y"), + project(ImmutableMap.of("foo_x", expression("foo_x"), "foo_y", expression("foo_y"), "bar_x", expression("bar.x"), "bar_y", expression("bar.y")), + tableScan("nested_column_table", ImmutableMap.of("foo_x", "foo.x", "foo_y", "foo.y", "bar", "bar"))))); + } + + @Test + public void testSelectDereferenceAndParentDoesNotFire() + { + assertPlan("select foo.x, foo.y, foo from nested_column_table", + output(ImmutableList.of("foo_x", "foo_y", "foo"), + project(ImmutableMap.of("foo_x", expression("foo.x"), "foo_y", expression("foo.y"), "foo", expression("foo")), + tableScan("nested_column_table", ImmutableMap.of("foo", 
"foo"))))); + } + + private static LocalQueryRunner createQueryRunner() + { + String schemaName = "test-schema"; + String catalogName = "test"; + TableInfo regularTable = new TableInfo(new SchemaTableName(schemaName, "regular_table"), + ImmutableList.of(new ColumnMetadata("dummy_column", VarcharType.VARCHAR)), ImmutableMap.of()); + + TableInfo nestedColumnTable = new TableInfo(new SchemaTableName(schemaName, "nested_column_table"), + ImmutableList.of(new ColumnMetadata("foo", MSG_TYPE), new ColumnMetadata("bar", MSG_TYPE)), ImmutableMap.of( + new NestedColumn(ImmutableList.of("foo", "x")), 0, + new NestedColumn(ImmutableList.of("foo", "y")), 0)); + + ImmutableList tableInfos = ImmutableList.of(regularTable, nestedColumnTable); + + LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder() + .setCatalog(catalogName) + .setSchema(schemaName) + .build()); + queryRunner.createCatalog(catalogName, new TestConnectorFactory(new TestMetadata(tableInfos)), ImmutableMap.of()); + return queryRunner; + } + + private static class TestConnectorFactory + implements ConnectorFactory + { + private final TestMetadata metadata; + + public TestConnectorFactory(TestMetadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public String getName() + { + return "test"; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new TestingHandleResolver(); + } + + @Override + public Connector create(String connectorId, Map config, ConnectorContext context) + { + return new TestConnector(metadata); + } + } + + private enum TransactionInstance + implements ConnectorTransactionHandle + { + INSTANCE + } + + private static class TestConnector + implements Connector + { + private final ConnectorMetadata metadata; + + private TestConnector(ConnectorMetadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return TransactionInstance.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new TestingSplitManager(ImmutableList.of()); + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return new ConnectorPageSourceProvider() + { + @Override + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorSplit split, ConnectorTableHandle table, List columns) + { + return new FixedPageSource(ImmutableList.of()); + } + }; + } + + @Override + public ConnectorPageSinkProvider getPageSinkProvider() + { + return new TestingPageSinkProvider(); + } + } + + private static class TestMetadata + extends TestingMetadata + { + private final List tableInfos; + + TestMetadata(List tableInfos) + { + this.tableInfos = requireNonNull(tableInfos, "tableinfos is null"); + insertTables(); + } + + private void insertTables() + { + for (TableInfo tableInfo : tableInfos) { + getTables().put(tableInfo.getSchemaTableName(), tableInfo.getTableMetadata()); + } + } + + @Override + public Map getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection dereferences) + { + requireNonNull(tableHandle, "tableHandle is null"); + SchemaTableName tableName = getTableName(tableHandle); + return tableInfos.stream().filter(tableInfo -> 
tableInfo.getSchemaTableName().equals(tableName)).map(TableInfo::getNestedColumnHandle).findFirst().orElse(ImmutableMap.of()); + } + } + + private static class TableInfo + { + private final SchemaTableName schemaTableName; + private final List columnMetadatas; + private final Map nestedColumns; + + public TableInfo(SchemaTableName schemaTableName, List columnMetadata, Map nestedColumns) + { + this.schemaTableName = schemaTableName; + this.columnMetadatas = columnMetadata; + this.nestedColumns = nestedColumns; + } + + SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + ConnectorTableMetadata getTableMetadata() + { + return new ConnectorTableMetadata(schemaTableName, columnMetadatas); + } + + Map getNestedColumnHandle() + { + return nestedColumns.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> { + Preconditions.checkArgument(entry.getValue() >= 0 && entry.getValue() < columnMetadatas.size(), "index is not valid"); + NestedColumn nestedColumn = entry.getKey(); + return new TestingMetadata.TestingColumnHandle(nestedColumn.getName(), entry.getValue(), VarcharType.VARCHAR); + })); + } + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java index 2ed59bc9795b..5cc7e874f9b1 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ColumnReference.java @@ -13,6 +13,8 @@ */ package io.prestosql.sql.planner.assertions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import io.prestosql.Session; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.TableHandle; @@ -23,6 +25,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.TableScanNode; +import java.util.AbstractMap; import java.util.Map; import java.util.Optional; @@ -91,8 +94,8 @@ private Optional getAssignedSymbol(Map assignments private Optional getColumnHandle(TableHandle tableHandle, Session session, Metadata metadata) { - return metadata.getColumnHandles(session, tableHandle).entrySet() - .stream() + return Streams.concat(metadata.getColumnHandles(session, tableHandle).entrySet().stream(), + metadata.getNestedColumnHandles(session, tableHandle, ImmutableList.of()).entrySet().stream().map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey().getName(), entry.getValue()))) .filter(entry -> columnName.equals(entry.getKey())) .map(Map.Entry::getValue) .findFirst(); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java index 00fa3d9f9c04..bf60645ba5da 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java @@ -21,6 +21,7 @@ import io.prestosql.sql.tree.CoalesceExpression; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.DecimalLiteral; +import io.prestosql.sql.tree.DereferenceExpression; import io.prestosql.sql.tree.DoubleLiteral; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.FunctionCall; @@ -34,8 +35,10 @@ import io.prestosql.sql.tree.Node; import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.NullLiteral; +import 
io.prestosql.sql.tree.SearchedCaseExpression; import io.prestosql.sql.tree.SimpleCaseExpression; import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SubscriptExpression; import io.prestosql.sql.tree.SymbolReference; import io.prestosql.sql.tree.TryExpression; import io.prestosql.sql.tree.WhenClause; @@ -97,6 +100,20 @@ protected Boolean visitTryExpression(TryExpression actual, Node expected) return process(actual.getInnerExpression(), ((TryExpression) expected).getInnerExpression()); } + @Override + protected Boolean visitDereferenceExpression(DereferenceExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof DereferenceExpression)) { + return false; + } + + DereferenceExpression expected = (DereferenceExpression) expectedExpression; + if (actual.getField().equals(expected.getField())) { + return process(actual.getBase(), expected.getBase()); + } + return false; + } + @Override protected Boolean visitCast(Cast actual, Node expectedExpression) { @@ -125,6 +142,18 @@ protected Boolean visitIsNullPredicate(IsNullPredicate actual, Node expectedExpr return process(actual.getValue(), expected.getValue()); } + @Override + protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SubscriptExpression)) { + return false; + } + + SubscriptExpression expected = (SubscriptExpression) expectedExpression; + + return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + } + @Override protected Boolean visitIsNotNullPredicate(IsNotNullPredicate actual, Node expectedExpression) { @@ -307,6 +336,17 @@ protected Boolean visitNotExpression(NotExpression actual, Node expected) return false; } + @Override + protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SearchedCaseExpression)) { + return false; + } + + SearchedCaseExpression expected = (SearchedCaseExpression) expectedExpression; + return process(actual.getDefaultValue(), expected.getDefaultValue()) && process(actual.getWhenClauses(), expected.getWhenClauses()); + } + @Override protected Boolean visitSymbolReference(SymbolReference actual, Node expected) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java index 2c6479d2f2bb..f4a72ab28a36 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java @@ -323,6 +323,11 @@ public static PlanMatchPattern strictOutput(List outputs, PlanMatchPatte return output(outputs, source).withExactOutputs(outputs); } + public static PlanMatchPattern unnest(List replicateSymbols, Map> unnestSymbols, Optional ordinalitySymbol, PlanMatchPattern source) + { + return unnest(source).with(new UnnestMatcher(replicateSymbols, unnestSymbols, ordinalitySymbol)); + } + public static PlanMatchPattern project(PlanMatchPattern source) { return node(ProjectNode.class, source); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java new file mode 100644 index 000000000000..18afa93c8e80 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java @@ -0,0 +1,71 @@ 
+/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.assertions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.prestosql.Session; +import io.prestosql.cost.StatsProvider; +import io.prestosql.metadata.Metadata; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.UnnestNode; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkState; +import static io.prestosql.sql.planner.assertions.MatchResult.match; +import static java.util.Objects.requireNonNull; + +public class UnnestMatcher + implements Matcher +{ + private final List replicateSymbols; + private final Map> unnestSymbols; + private final Optional ordinalitySymbol; + + public UnnestMatcher(List replicateSymbols, Map> unnestSymbols, Optional ordinalitySymbol) + { + this.replicateSymbols = requireNonNull(replicateSymbols, "replicateSymbols is null"); + this.unnestSymbols = requireNonNull(unnestSymbols); + this.ordinalitySymbol = requireNonNull(ordinalitySymbol); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof UnnestNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + UnnestNode unnestNode = (UnnestNode) node; + + ImmutableList.Builder aliasBuilder = ImmutableList.builder().addAll(replicateSymbols).addAll(Iterables.concat(unnestSymbols.values())); + ordinalitySymbol.ifPresent(aliasBuilder::add); + List alias = aliasBuilder.build(); + + if (alias.size() != unnestNode.getOutputSymbols().size()) { + return MatchResult.NO_MATCH; + } + + SymbolAliases.Builder builder = SymbolAliases.builder().putAll(symbolAliases); + IntStream.range(0, alias.size()).forEach(i -> builder.put(alias.get(i), unnestNode.getOutputSymbols().get(i).toSymbolReference())); + return match(builder.build()); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java index de3c18b66f24..eb478cbae495 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/RuleTester.java @@ -43,6 +43,12 @@ public class RuleTester private final Metadata metadata; private final Session session; + + public LocalQueryRunner getQueryRunner() + { + return queryRunner; + } + private final LocalQueryRunner queryRunner; private final TransactionManager transactionManager; private final SplitManager splitManager; diff --git a/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java 
b/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java index 59a4ba55a755..d267608bd182 100644 --- a/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java +++ b/presto-parquet/src/main/java/io/prestosql/parquet/ParquetTypeUtils.java @@ -13,6 +13,9 @@ */ package io.prestosql.parquet; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import io.prestosql.spi.NestedColumn; import io.prestosql.spi.type.DecimalType; import io.prestosql.spi.type.Type; import org.apache.parquet.column.Encoding; @@ -24,6 +27,7 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.PrimitiveColumnIO; import org.apache.parquet.schema.DecimalMetadata; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; import java.util.Arrays; @@ -34,6 +38,7 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterators.getOnlyElement; import static org.apache.parquet.schema.OriginalType.DECIMAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; @@ -185,14 +190,14 @@ public static ParquetEncoding getParquetEncoding(Encoding encoding) } } - public static org.apache.parquet.schema.Type getParquetTypeByName(String columnName, MessageType messageType) + public static org.apache.parquet.schema.Type getParquetTypeByName(String columnName, GroupType groupType) { - if (messageType.containsField(columnName)) { - return messageType.getType(columnName); + if (groupType.containsField(columnName)) { + return groupType.getType(columnName); } // parquet is case-sensitive, but hive is not. all hive columns get converted to lowercase // check for direct match above but if no match found, try case-insensitive match - for (org.apache.parquet.schema.Type type : messageType.getFields()) { + for (org.apache.parquet.schema.Type type : groupType.getFields()) { if (type.getName().equalsIgnoreCase(columnName)) { return type; } @@ -262,4 +267,37 @@ public static long getShortDecimalValue(byte[] bytes) return value; } + + public static org.apache.parquet.schema.Type getNestedColumnType(GroupType baseType, NestedColumn nestedColumn) + { + Preconditions.checkArgument(nestedColumn.getNames().size() >= 1, "fields size is less than 1"); + + ImmutableList.Builder typeBuilder = ImmutableList.builder(); + org.apache.parquet.schema.Type parentType = baseType; + + for (String field : nestedColumn.getNames()) { + org.apache.parquet.schema.Type childType = getParquetTypeByName(field, parentType.asGroupType()); + if (childType == null) { + return null; + } + typeBuilder.add(childType); + parentType = childType; + } + List typeChain = typeBuilder.build(); + + if (typeChain.isEmpty()) { + return null; + } + else if (typeChain.size() == 1) { + return getOnlyElement(typeChain.iterator()); + } + else { + org.apache.parquet.schema.Type messageType = typeChain.get(typeChain.size() - 1); + for (int i = typeChain.size() - 2; i >= 0; --i) { + GroupType groupType = typeChain.get(i).asGroupType(); + messageType = new MessageType(groupType.getName(), ImmutableList.of(messageType)); + } + return messageType; + } + } } diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java b/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java index d5401a8a1b5e..5a812441eafc 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java +++ 
diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java b/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java
index d5401a8a1b5e..5a812441eafc 100644
--- a/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java
+++ b/presto-parser/src/main/java/io/prestosql/sql/tree/DereferenceExpression.java
@@ -20,6 +20,7 @@
 import java.util.Optional;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Locale.ENGLISH;
 
 public class DereferenceExpression
         extends Expression
@@ -118,12 +119,12 @@ public boolean equals(Object o)
         }
         DereferenceExpression that = (DereferenceExpression) o;
         return Objects.equals(base, that.base) &&
-                Objects.equals(field, that.field);
+                Objects.equals(field.getValue().toLowerCase(ENGLISH), that.field.getValue().toLowerCase(ENGLISH));
     }
 
     @Override
     public int hashCode()
     {
-        return Objects.hash(base, field);
+        return Objects.hash(base, field.getValue().toLowerCase(ENGLISH));
     }
 }
diff --git a/presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java b/presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java
new file mode 100644
index 000000000000..0806d903a8e0
--- /dev/null
+++ b/presto-spi/src/main/java/io/prestosql/spi/NestedColumn.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.spi;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import io.prestosql.spi.type.Type;
+
+import java.util.List;
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+public class NestedColumn
+{
+    private final List<String> names;
+
+    @JsonCreator
+    public NestedColumn(@JsonProperty("names") List<String> names)
+    {
+        this.names = requireNonNull(names, "names is null");
+    }
+
+    @JsonProperty
+    public List<String> getNames()
+    {
+        return names;
+    }
+
+    public String getBase()
+    {
+        return names.get(0);
+    }
+
+    public List<String> getRest()
+    {
+        // TODO assert size > 1
+        return names.subList(1, names.size());
+    }
+
+    public String getName()
+    {
+        return String.join(".", names);
+    }
+
+    @JsonProperty
+    public Type getType()
+    {
+        // TODO: the dereferenced type is not tracked yet
+        return null;
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return Objects.hash(names);
+    }
+
+    @Override
+    public boolean equals(Object obj)
+    {
+        if (this == obj) {
+            return true;
+        }
+
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+
+        NestedColumn other = (NestedColumn) obj;
+        return Objects.equals(this.names, other.names);
+    }
+
+    @Override
+    public String toString()
+    {
+        return "NestedColumn{name='" + getName() + "'}";
+    }
+}
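A quick sketch of what a NestedColumn carries, since its accessors are what the pushdown machinery keys off; the path names below are illustrative only.

    import io.prestosql.spi.NestedColumn;

    import java.util.Arrays;

    public class NestedColumnExample
    {
        public static void main(String[] args)
        {
            // represents the dereference chain msg.user.age
            NestedColumn column = new NestedColumn(Arrays.asList("msg", "user", "age"));

            System.out.println(column.getBase());   // msg           -> the top-level hive column
            System.out.println(column.getRest());   // [user, age]   -> the path inside that column
            System.out.println(column.getName());   // msg.user.age  -> the dotted handle name
        }
    }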
diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java
index cc593d7c2d00..2546b9aa14f2 100644
--- a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java
+++ b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java
@@ -14,6 +14,7 @@
 package io.prestosql.spi.connector;
 
 import io.airlift.slice.Slice;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.PrestoException;
 import io.prestosql.spi.predicate.TupleDomain;
 import io.prestosql.spi.security.GrantInfo;
@@ -27,6 +28,7 @@
 import javax.annotation.Nullable;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -629,4 +631,9 @@ default Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(
     {
         return Optional.empty();
     }
+
+    default Map<NestedColumn, ColumnHandle> getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<NestedColumn> dereferences)
+    {
+        return Collections.emptyMap();
+    }
 }
diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
index 0b3f717456bc..8840ec11aadf 100644
--- a/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
+++ b/presto-spi/src/main/java/io/prestosql/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java
@@ -14,6 +14,7 @@
 package io.prestosql.spi.connector.classloader;
 
 import io.airlift.slice.Slice;
+import io.prestosql.spi.NestedColumn;
 import io.prestosql.spi.classloader.ThreadContextClassLoader;
 import io.prestosql.spi.connector.ColumnHandle;
 import io.prestosql.spi.connector.ColumnMetadata;
@@ -583,4 +584,12 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
             return delegate.applyFilter(session, table, constraint);
         }
     }
+
+    @Override
+    public Map<NestedColumn, ColumnHandle> getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<NestedColumn> dereferences)
+    {
+        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+            return delegate.getNestedColumnHandles(session, tableHandle, dereferences);
+        }
+    }
 }
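A connector that supports dereference pushdown would override the default rather than inherit the empty map. The following is a minimal sketch only: findColumnForPath is a hypothetical lookup helper, and the handle types stand in for whatever the connector actually uses (the Hive connector's real implementation is part of this change).

    // Hypothetical override in a connector's ConnectorMetadata implementation.
    // Paths the connector cannot resolve are simply left out of the result map.
    @Override
    public Map<NestedColumn, ColumnHandle> getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection<NestedColumn> dereferences)
    {
        Map<NestedColumn, ColumnHandle> handles = new HashMap<>();
        for (NestedColumn dereference : dereferences) {
            findColumnForPath(tableHandle, dereference)            // hypothetical, returns Optional<ColumnHandle>
                    .ifPresent(handle -> handles.put(dereference, handle));
        }
        return Collections.unmodifiableMap(handles);
    }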