From 0edff2e9c16f3cb2d12e2647f6d334984fedd394 Mon Sep 17 00:00:00 2001 From: Thang Long VU Date: Tue, 9 Apr 2024 19:40:10 +0200 Subject: [PATCH] Update --- .../sql/DeltaSparkSessionExtension.scala | 3 + .../delta/ColumnWithDefaultExprUtils.scala | 6 + .../spark/sql/delta/DeltaColumnMapping.scala | 6 +- .../sql/delta/DeltaParquetFileFormat.scala | 20 +- .../spark/sql/delta/GenerateRowIDs.scala | 139 +++++++ .../org/apache/spark/sql/delta/RowId.scala | 115 +++++- .../apache/spark/sql/delta/RowTracking.scala | 12 + .../sql/delta/files/TransactionalWrite.scala | 5 + .../schema/ImplicitMetadataOperation.scala | 16 +- .../spark/sql/delta/schema/SchemaUtils.scala | 4 + .../delta/DeltaParquetFileFormatSuite.scala | 1 + .../sql/delta/rowid/GenerateRowIDsSuite.scala | 172 ++++++++ .../spark/sql/delta/rowid/RowIdSuite.scala | 374 +++++++++++++++++- .../sql/delta/rowid/RowIdTestUtils.scala | 33 +- .../sql/delta/schema/SchemaUtilsSuite.scala | 23 ++ 15 files changed, 915 insertions(+), 14 deletions(-) create mode 100644 spark/src/main/scala/org/apache/spark/sql/delta/GenerateRowIDs.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/rowid/GenerateRowIDsSuite.scala diff --git a/spark/src/main/scala/io/delta/sql/DeltaSparkSessionExtension.scala b/spark/src/main/scala/io/delta/sql/DeltaSparkSessionExtension.scala index 19e8493b551..1f4fca352b5 100644 --- a/spark/src/main/scala/io/delta/sql/DeltaSparkSessionExtension.scala +++ b/spark/src/main/scala/io/delta/sql/DeltaSparkSessionExtension.scala @@ -125,6 +125,9 @@ class DeltaSparkSessionExtension extends (SparkSessionExtensions => Unit) { extensions.injectPostHocResolutionRule { session => PostHocResolveUpCast(session) } + + extensions.injectPlanNormalizationRule { _ => GenerateRowIDs } + // We don't use `injectOptimizerRule` here as we won't want to apply further optimizations after // `PrepareDeltaScan`. // For example, `ConstantFolding` will break unit tests in `OptimizeGeneratedColumnSuite`. 
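Illustrative usage sketch (not part of the patch): once the `GenerateRowIDs` plan normalization rule above is injected, Row IDs become readable through the `_metadata.row_id` column, which the rule expands into coalesce(_metadata.row_id, _metadata.base_row_id + _metadata.row_index). The snippet below assumes a local Spark session configured with the Delta extension and the standard Delta catalog class, and uses a hypothetical table name `demo`; it shows intended usage only, under those assumptions.

  import org.apache.spark.sql.SparkSession

  // Hypothetical session setup: the extension class is the one modified by this patch,
  // the catalog class is the usual Delta catalog.
  val spark = SparkSession.builder()
    .appName("row-id-example")
    .master("local[*]")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog",
      "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()

  // Create a table with Row Tracking enabled and append some rows.
  spark.sql(
    "CREATE TABLE demo (id BIGINT) USING delta " +
      "TBLPROPERTIES ('delta.enableRowTracking' = 'true')")
  spark.range(0, 10).toDF("id").write.format("delta").mode("append").saveAsTable("demo")

  // Rows without a materialized Row ID value get base_row_id + row_index through the
  // projection added by GenerateRowIDs.
  spark.table("demo").select("id", "_metadata.row_id").show()

The rule only rewrites scans of tables where Row Tracking is enabled (see the DeltaScanWithRowTrackingEnabled matcher below); other Delta scans are left untouched.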
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala index 55b3e658784..e47abb3d19e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala @@ -153,6 +153,12 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { } } selectExprs = selectExprs ++ cdcSelectExprs + + val rowIdExprs = data.queryExecution.analyzed.output + .filter(RowId.RowIdMetadataAttribute.isRowIdColumn) + .map(new Column(_)) + selectExprs = selectExprs ++ rowIdExprs + val newData = queryExecution match { case incrementalExecution: IncrementalExecution => selectFromStreamingDataFrame(incrementalExecution, data, selectExprs: _*) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala index 73cfac97721..49c3c0dd075 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaColumnMapping.scala @@ -20,6 +20,7 @@ import java.util.{Locale, UUID} import scala.collection.mutable +import org.apache.spark.sql.delta.RowId.RowIdMetadataStructField import org.apache.spark.sql.delta.actions.{Metadata, Protocol} import org.apache.spark.sql.delta.commands.cdc.CDCReader import org.apache.spark.sql.delta.metering.DeltaLogging @@ -78,8 +79,9 @@ trait DeltaColumnMappingBase extends DeltaLogging { val supportedModes: Set[DeltaColumnMappingMode] = Set(IdMapping, NoMapping, NameMapping) - def isInternalField(field: StructField): Boolean = DELTA_INTERNAL_COLUMNS - .contains(field.name.toLowerCase(Locale.ROOT)) + def isInternalField(field: StructField): Boolean = + DELTA_INTERNAL_COLUMNS.contains(field.name.toLowerCase(Locale.ROOT)) || + RowIdMetadataStructField.isRowIdColumn(field) def satisfiesColumnMappingProtocol(protocol: Protocol): Boolean = protocol.isFeatureSupported(ColumnMappingTableFeature) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaParquetFileFormat.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaParquetFileFormat.scala index 6938256bf6f..83fcee9692d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaParquetFileFormat.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaParquetFileFormat.scala @@ -55,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration case class DeltaParquetFileFormat( protocol: Protocol, metadata: Metadata, + nullableRowTrackingFields: Boolean = false, isSplittable: Boolean = true, disablePushDowns: Boolean = false, tablePath: Option[String] = None, @@ -197,12 +198,21 @@ case class DeltaParquetFileFormat( } override def metadataSchemaFields: Seq[StructField] = { - // Parquet reader in Spark has a bug where a file containing 2b+ rows in a single rowgroup - // causes it to run out of the `Integer` range (TODO: Create a SPARK issue) + val rowTrackingFields = + RowTracking.createMetadataStructFields(protocol, metadata, nullableRowTrackingFields) + // TODO(SPARK-47731): Parquet reader in Spark has a bug where a file containing 2b+ rows + // in a single rowgroup causes it to run out of the `Integer` range. // For Delta Parquet readers don't expose the row_index field as a metadata field. 
- super.metadataSchemaFields.filter(field => field != ParquetFileFormat.ROW_INDEX_FIELD) ++ - RowId.createBaseRowIdField(protocol, metadata) ++ - DefaultRowCommitVersion.createDefaultRowCommitVersionField(protocol, metadata) + if (!RowId.isEnabled(protocol, metadata)) { + super.metadataSchemaFields.filter(_ != ParquetFileFormat.ROW_INDEX_FIELD) + } else { + // It is fine to expose the row_index field as a metadata field when Row Tracking + // is enabled because it is needed to generate the Row ID field, and it is not a + // big problem if a single rowgroup contains 2b+ rows: the reader will throw an exception and + // we can then use fewer rows per rowgroup. Also, 2b+ rows in a single rowgroup is + // not a common use case. + super.metadataSchemaFields ++ rowTrackingFields + } } override def prepareWrite( diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/GenerateRowIDs.scala b/spark/src/main/scala/org/apache/spark/sql/delta/GenerateRowIDs.scala new file mode 100644 index 00000000000..fd464d7986e --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/GenerateRowIDs.scala @@ -0,0 +1,139 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION +import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.types.StructType + +/** + * This rule adds a Project on top of Delta tables that support the Row tracking table feature to + * provide a default generated Row ID for rows that don't have one materialized in the data file. + */ +object GenerateRowIDs extends Rule[LogicalPlan] { + + /** + * Matcher for a scan on a Delta table that has Row tracking enabled. + */ + private object DeltaScanWithRowTrackingEnabled { + def unapply(plan: LogicalPlan): Option[LogicalRelation] = plan match { + case scan @ LogicalRelation(relation: HadoopFsRelation, _, _, _) => + relation.fileFormat match { + case format: DeltaParquetFileFormat + if RowTracking.isEnabled(format.protocol, format.metadata) => Some(scan) + case _ => None + } + case _ => None + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { + case DeltaScanWithRowTrackingEnabled(scan) => + // While Row IDs are non-nullable, we'll use the Row ID attributes to read + // the materialized values from now on, which can be null. We make + // the materialized Row ID attributes nullable in the scan here. + + // Update nullability in the scan `metadataOutput` by updating the delta file format.
+ val baseRelation = scan.relation.asInstanceOf[HadoopFsRelation] + val newFileFormat = baseRelation.fileFormat match { + case format: DeltaParquetFileFormat => + format.copy(nullableRowTrackingFields = true) + } + val newBaseRelation = baseRelation.copy(fileFormat = newFileFormat)(baseRelation.sparkSession) + + // Update the output metadata column's data type (now with nullable row tracking fields). + val newOutput = scan.output.map { + case MetadataAttributeWithLogicalName(metadata, FileFormat.METADATA_NAME) => + metadata.withDataType(newFileFormat.createFileMetadataCol().dataType) + case other => other + } + val newScan = scan.copy(relation = newBaseRelation, output = newOutput) + newScan.copyTagsFrom(scan) + + // Add projection with row tracking column expressions. + val updatedAttributes = mutable.Buffer.empty[(Attribute, Attribute)] + val projectList = newOutput.map { + case MetadataAttributeWithLogicalName(metadata, FileFormat.METADATA_NAME) => + val updatedMetadata = metadataWithRowTrackingColumnsProjection(metadata) + updatedAttributes += metadata -> updatedMetadata.toAttribute + updatedMetadata + case other => other + } + Project(projectList = projectList, child = newScan) -> updatedAttributes.toSeq + case o => + val newPlan = o.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { + // Recurse into subquery plans. Similar to how [[transformUpWithSubqueries]] works except + // that it allows us to still use [[transformUpWithNewOutput]] on subquery plans to + // correctly update references to the metadata attribute when going up the plan. + // Get around type erasure by explicitly checking the plan type and removing warning. + case planExpression: PlanExpression[LogicalPlan @unchecked] + if planExpression.plan.isInstanceOf[LogicalPlan] => + planExpression.withNewPlan(apply(planExpression.plan)) + } + newPlan -> Nil + } + + /** + * Expression that reads the Row IDs from the materialized Row ID column if the value is + * present and returns the default generated Row ID using the file's base Row ID and current row + * index if not: + * coalesce(_metadata.row_id, _metadata.base_row_id + _metadata.row_index). + */ + private def rowIdExpr(metadata: AttributeReference): Expression = { + Coalesce(Seq( + getField(metadata, RowId.ROW_ID), + Add( + getField(metadata, RowId.BASE_ROW_ID), + getField(metadata, ParquetFileFormat.ROW_INDEX)))) + } + + /** + * Extract a field from the metadata column. + */ + private def getField(metadata: AttributeReference, name: String): GetStructField = { + ExtractValue(metadata, Literal(name), conf.resolver) match { + case field: GetStructField => field + case _ => + throw new IllegalStateException(s"The metadata column '${metadata.name}' is not a struct.") + } + } + + /** + * Create a new metadata struct where the Row ID values are populated using + * the materialized values if present, or the default Row ID values if not. + */ + private def metadataWithRowTrackingColumnsProjection(metadata: AttributeReference) + : NamedExpression = { + val metadataFields = metadata.dataType.asInstanceOf[StructType].map { + case field if field.name == RowId.ROW_ID => + field -> rowIdExpr(metadata) + case field => + field -> getField(metadata, field.name) + }.flatMap { case (oldField, newExpr) => + // Propagate the type metadata from the old fields to the new fields. 
+ val newField = Alias(newExpr, oldField.name)(explicitMetadata = Some(oldField.metadata)) + Seq(Literal(oldField.name), newField) + } + Alias(CreateNamedStruct(metadataFields), metadata.name)() + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/RowId.scala b/spark/src/main/scala/org/apache/spark/sql/delta/RowId.scala index c68497ad8a7..4739801c85b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/RowId.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/RowId.scala @@ -20,7 +20,11 @@ import org.apache.spark.sql.delta.actions.{Action, AddFile, DomainMetadata, Meta import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils.propertyKey import org.apache.spark.sql.util.ScalaExtensions._ -import org.apache.spark.sql.catalyst.expressions.FileSourceConstantMetadataStructField +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, FileSourceConstantMetadataStructField, FileSourceGeneratedMetadataStructField} +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.types import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder, StructField} @@ -155,4 +159,113 @@ object RowId { Option.when(RowId.isEnabled(protocol, metadata)) { BaseRowIdMetadataStructField() } + + /** Row ID column name */ + val ROW_ID = "row_id" + + val QUALIFIED_COLUMN_NAME = s"${FileFormat.METADATA_NAME}.${ROW_ID}" + + /** Column metadata to be used in conjunction with [[QUALIFIED_COLUMN_NAME]] to mark row id columns */ + def columnMetadata(materializedColumnName: String): types.Metadata = + RowIdMetadataStructField.metadata(materializedColumnName) + + /** + * The field that readers can use to access the generated row id column. The scanner's internal column + * name is obtained from the table's metadata. + */ + def createRowIdField(protocol: Protocol, metadata: Metadata, nullable: Boolean) + : Option[StructField] = + MaterializedRowId.getMaterializedColumnName(protocol, metadata) + .map(RowIdMetadataStructField(_, nullable)) + + /** + * A specialization of [[FileSourceGeneratedMetadataStructField]] used to represent RowId columns. + * + * - Row ID columns can be read by adding '_metadata.row_id' to the read schema + * - To write to the materialized Row ID column + * - use the materialized Row ID column name which can be obtained using + * [[getMaterializedColumnName]] + * - add [[columnMetadata]] which is part of [[RowId]] as metadata to the column + * - nulls are replaced with fresh Row IDs + */ + object RowIdMetadataStructField { + + val ROW_ID_METADATA_COL_ATTR_KEY = "__row_id_metadata_col" + + def metadata(materializedColumnName: String): types.Metadata = new MetadataBuilder() + .withMetadata( + FileSourceGeneratedMetadataStructField.metadata(RowId.ROW_ID, materializedColumnName)) + .putBoolean(ROW_ID_METADATA_COL_ATTR_KEY, value = true) + .build() + + def apply(materializedColumnName: String, nullable: Boolean = false): StructField = + StructField( + RowId.ROW_ID, + LongType, + // The Row ID field is used to read the materialized Row ID value, which is nullable. The + // actual Row ID expression is created using a projection injected before the optimizer pass + // by the [[GenerateRowIDs]] rule at which point the Row ID field is non-nullable.
+ nullable, + metadata = metadata(materializedColumnName)) + + def unapply(field: StructField): Option[StructField] = + if (isRowIdColumn(field)) Some(field) else None + + /** Return true if the column is a Row Id column. */ + def isRowIdColumn(structField: StructField): Boolean = + isValid(structField.dataType, structField.metadata) + + def isValid(dataType: DataType, metadata: types.Metadata): Boolean = { + FileSourceGeneratedMetadataStructField.isValid(dataType, metadata) && + metadata.contains(ROW_ID_METADATA_COL_ATTR_KEY) && + metadata.getBoolean(ROW_ID_METADATA_COL_ATTR_KEY) + } + } + + object RowIdMetadataAttribute { + /** Creates an attribute for writing out the materialized column name */ + def apply(materializedColumnName: String): AttributeReference = + DataTypeUtils.toAttribute(RowIdMetadataStructField(materializedColumnName)) + .withName(materializedColumnName) + + def unapply(attr: Attribute): Option[Attribute] = + if (isRowIdColumn(attr)) Some(attr) else None + + /** Return true if the column is a Row Id column. */ + def isRowIdColumn(attr: Attribute): Boolean = + RowIdMetadataStructField.isValid(attr.dataType, attr.metadata) + } + + /** + * Throw if row tracking is supported and columns in the write schema tagged as materialized row + * IDs do not reference the materialized row id column name. + */ + private[delta] def throwIfMaterializedRowIdColumnNameIsInvalid( + data: DataFrame, metadata: Metadata, protocol: Protocol, tableId: String): Unit = { + if (!RowTracking.isEnabled(protocol, metadata)) { + return + } + + val materializedColumnName = + metadata.configuration.get(MaterializedRowId.MATERIALIZED_COLUMN_NAME_PROP) + + if (materializedColumnName.isEmpty) { + // If row tracking is enabled, a missing materialized column name is a bug and we need to + // throw an error. If row tracking is only supported, we should just return, as it's fine + // for the materialized column to not be assigned. + if (RowTracking.isEnabled(protocol, metadata)) { + throw DeltaErrors.materializedRowIdMetadataMissing(tableId) + } + return + } + + toAttributes(data.schema).foreach { + case RowIdMetadataAttribute(attribute) => + if (attribute.name != materializedColumnName.get) { + throw new UnsupportedOperationException("Materialized Row IDs column name " + + s"${attribute.name} is invalid. Must be ${materializedColumnName.get}.") + } + case _ => + } + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/RowTracking.scala b/spark/src/main/scala/org/apache/spark/sql/delta/RowTracking.scala index 12d23d58dd8..0dd18c814ab 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/RowTracking.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/RowTracking.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.delta import org.apache.spark.sql.delta.actions.{Metadata, Protocol, TableFeatureProtocolUtils} +import org.apache.spark.sql.types.StructField /** * Utility functions for Row Tracking that are shared between Row IDs and Row Commit Versions. @@ -60,4 +61,15 @@ object RowTracking { throw DeltaErrors.convertToDeltaRowTrackingEnabledWithoutStatsCollection } } + + /** + * Returns the Row Tracking metadata fields for the file's _metadata when Row Tracking + * is enabled. 
+ */ + def createMetadataStructFields(protocol: Protocol, metadata: Metadata, nullable: Boolean) + : Iterable[StructField] = { + RowId.createRowIdField(protocol, metadata, nullable) ++ + RowId.createBaseRowIdField(protocol, metadata) ++ + DefaultRowCommitVersion.createDefaultRowCommitVersionField(protocol, metadata) + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala b/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala index ad5efdb7581..67e5219ae2c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala @@ -106,6 +106,11 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl val normalizedData = SchemaUtils.normalizeColumnNames( deltaLog, metadata.schema, data ) + + // Validate that write columns for Row IDs have the correct name. + RowId.throwIfMaterializedRowIdColumnNameIsInvalid( + normalizedData, metadata, protocol, deltaLog.tableId) + val nullAsDefault = options.isDefined && options.get.options.contains(ColumnWithDefaultExprUtils.USE_NULL_AS_DEFAULT_DELTA_OPTION) val enforcesDefaultExprs = ColumnWithDefaultExprUtils.tableHasDefaultExpr( diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/ImplicitMetadataOperation.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/ImplicitMetadataOperation.scala index 20320440e1f..5fb7be4eb38 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/ImplicitMetadataOperation.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/ImplicitMetadataOperation.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.util.PartitionUtils import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.FileSourceGeneratedMetadataStructField import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.types.StructType @@ -53,6 +54,15 @@ trait ImplicitMetadataOperation extends DeltaLogging { } } + /** Remove all file source generated metadata columns from the schema. */ + private def dropGeneratedMetadataColumns(structType: StructType): StructType = { + val fields = structType.filter { + case FileSourceGeneratedMetadataStructField(_, _) => false + case _ => true + } + StructType(fields) + } + protected final def updateMetadata( spark: SparkSession, txn: OptimisticTransaction, @@ -65,8 +75,12 @@ trait ImplicitMetadataOperation extends DeltaLogging { // To support the new column mapping mode, we drop existing metadata on data schema // so that all the column mapping related properties can be reinitialized in // OptimisticTransaction.updateMetadata - val dataSchema = + var dataSchema = DeltaColumnMapping.dropColumnMappingMetadata(schema.asNullable) + + // File Source generated columns are not added to the stored schema. 
+ dataSchema = dropGeneratedMetadataColumns(dataSchema) + val mergedSchema = mergeSchema(txn, dataSchema, isOverwriteMode, canOverwriteSchema) val normalizedPartitionCols = normalizePartitionColumns(spark, partitionColumns, dataSchema) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index 7d13ebd488e..c8dd66af470 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaColumnMappingMode, DeltaErrors, DeltaLog, GeneratedColumn, NoMapping, TypeWidening} +import org.apache.spark.sql.delta.RowId import org.apache.spark.sql.delta.actions.Protocol import org.apache.spark.sql.delta.commands.cdc.CDCReader import org.apache.spark.sql.delta.metering.DeltaLogging @@ -320,6 +321,9 @@ def normalizeColumnNamesInDataType( // in the table schema. case None if field.name == CDCReader.CDC_TYPE_COLUMN_NAME || field.name == CDCReader.CDC_PARTITION_COL => (field.name, None) + // Consider Row Id columns internal if Row Ids are enabled. + case None if RowId.RowIdMetadataStructField.isRowIdColumn(field) => + (field.name, None) case None => throw DeltaErrors.cannotResolveColumn(field.name, baseSchema) } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaParquetFileFormatSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaParquetFileFormatSuite.scala index 1975aaa0869..6a90293d223 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaParquetFileFormatSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaParquetFileFormatSuite.scala @@ -90,6 +90,7 @@ class DeltaParquetFileFormatSuite extends QueryTest val deltaParquetFormat = new DeltaParquetFileFormat( deltaLog.snapshot.protocol, metadata, + nullableRowTrackingFields = false, isSplittable = false, disablePushDowns = true, Some(tablePath), diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/GenerateRowIDsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/GenerateRowIDsSuite.scala new file mode 100644 index 00000000000..781224267cf --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/GenerateRowIDsSuite.scala @@ -0,0 +1,172 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.delta.rowid + +import org.apache.spark.sql.delta.DeltaTestUtils.BOOLEAN_DOMAIN +import org.apache.spark.sql.delta.RowId + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, Coalesce, EqualTo, Expression, FileSourceMetadataAttribute, GetStructField, MetadataAttributeWithLogicalName} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.types.StructType + +/** + * This test suite checks the optimized logical plans produced after applying the [[GenerateRowIDs]] + * rule. It ensures that the rule is correctly applied to all Delta scans in different scenarios and + * that the optimizer is able to remove redundant expressions or nodes when possible. + */ +class GenerateRowIDsSuite extends QueryTest with RowIdTestUtils { + protected val testTable: String = "generateRowIDsTestTable" + + override def beforeAll(): Unit = { + super.beforeAll() + withRowTrackingEnabled(enabled = true) { + spark.range(start = 0, end = 20) + .toDF("id") + .write + .format("delta") + .saveAsTable(testTable) + } + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE IF EXISTS $testTable") + super.afterAll() + } + + /** + * Test runner checking that the optimized plan for the given dataframe matches the expected plan. + * The expected plan is defined as a partial function `check`, e.g.: + * check = { + * case Project(_, LogicalRelation) => // Do additional checks + * } + * + * Note: Pass `df` by name to avoid evaluating anything before test setup. + */ + protected def testRowIdPlan( + testName: String, df: => DataFrame, rowTrackingEnabled: Boolean = true)( + check: PartialFunction[LogicalPlan, Unit]): Unit = { + test(testName) { + withRowTrackingEnabled(enabled = rowTrackingEnabled) { + check.applyOrElse(df.queryExecution.optimizedPlan, { plan: LogicalPlan => + fail(s"Unexpected optimized plan: $plan") + }) + } + } + } + + /** + * Checks that the given expression corresponds to the expression used to generate Row IDs: + * coalesce(_metadata.row_id, _metadata.base_row_id + _metadata.row_index). + */ + protected def checkRowIdExpr(expr: Expression): Unit = { + expr match { + case Coalesce( + Seq( + GetStructField(FileSourceMetadataAttribute(_), _, _), + Add( + GetStructField(FileSourceMetadataAttribute(_), _, _), + GetStructField(FileSourceMetadataAttribute(_), _, _), + _))) => () + case Alias(aliasedExpr, RowId.ROW_ID) => checkRowIdExpr(aliasedExpr) + case _ => fail(s"Expression didn't match expected Row ID expression: $expr") + } + } + + /** + * Checks that a metadata column is present in `output` and that it contains the given fields and + * only these. 
+ */ + protected def checkMetadataFieldsPresent( + output: Seq[AttributeReference], + expectedFieldNames: Seq[String]) + : Unit = { + val metadataSchema = output.collect { + case FileSourceMetadataAttribute( + MetadataAttributeWithLogicalName( + AttributeReference(_, schema: StructType, _, _), _)) => schema + } + assert(metadataSchema.nonEmpty, s"No metadata column present in output: $output") + assert(metadataSchema.head.fieldNames === expectedFieldNames, + "Unexpected metadata fields present in the metadata output.") + } + + for (rowTrackingEnabled <- BOOLEAN_DOMAIN) + testRowIdPlan(s"Regular column selected, rowTrackingEnabled: $rowTrackingEnabled", + sql(s"SELECT id FROM $testTable"), rowTrackingEnabled) { + // No projection is added when no metadata column is selected. + case lr: LogicalRelation => + assert(lr.output.map(_.name) === Seq("id"), "Scan list didn't match") + } + + for (rowTrackingEnabled <- BOOLEAN_DOMAIN) + testRowIdPlan(s"Metadata column selected, rowTrackingEnabled: $rowTrackingEnabled", + sql(s"SELECT _metadata.file_path FROM $testTable"), rowTrackingEnabled) { + // Selecting a metadata column adds a projection to unpack metadata fields (unrelated to Row + // IDs). Row IDs don't introduce an extra projection. + case Project(projectList, lr: LogicalRelation) => + assert(projectList.map(_.name) === Seq("file_path"), "Project list didn't match") + assert(lr.output.map(_.name) === Seq("id", "_metadata"), "Scan list didn't match") + checkMetadataFieldsPresent(lr.output, Seq("file_path")) + } + + testRowIdPlan("Row ID column selected", sql(s"SELECT _metadata.row_id FROM $testTable")) { + // Selecting Row IDs injects an expression to generate default Row IDs. + case Project(Seq(rowIdExpr), lr: LogicalRelation) => + assert(rowIdExpr.name == RowId.ROW_ID) + checkRowIdExpr(rowIdExpr) + assert(lr.output.map(_.name) === Seq("id", "_metadata")) + checkMetadataFieldsPresent(lr.output, Seq("row_index", "row_id", "base_row_id")) + } + + testRowIdPlan("Filter on Row ID column", + sql(s"SELECT * FROM $testTable WHERE _metadata.row_id = 5")) { + // Filtering on Row IDs injects an expression to generate default Row IDs in the filter. + case Project(projectList, Filter(EqualTo(rowIdExpr, _), lr: LogicalRelation)) => + assert(projectList.map(_.name) === Seq("id"), "Project list didn't match") + checkRowIdExpr(rowIdExpr) + assert(lr.output.map(_.name) === Seq("id", "_metadata"), "Scan list didn't match") + checkMetadataFieldsPresent(lr.output, Seq("row_index", "row_id", "base_row_id")) + } + + testRowIdPlan("Filter on Row ID in subquery", + sql(s"SELECT * FROM $testTable WHERE _metadata.row_id IN (SELECT id FROM $testTable)")) { + // Filtering on Row IDs using a subquery injects an expression to generate default Row IDs in + // the subquery. 
+ case Project( + projectList, + Join(right: LogicalRelation, left: LogicalPlan, _, joinCond, _)) => + assert(projectList.map(_.name) === Seq("id"), "Project list didn't match") + assert(right.output.map(_.name) === Seq("id", "_metadata"), "Outer scan output didn't match") + checkMetadataFieldsPresent(right.output, Seq("row_index", "row_id", "base_row_id")) + assert(left.output.map(_.name) === Seq("id"), "Subquery scan output didn't match") + joinCond match { + case Some(EqualTo(rowIdExpr, _)) => + checkRowIdExpr(rowIdExpr) + case _ => fail(s"Subquery was transformed into a join with an unexpected condition.") + } + } + + testRowIdPlan("Rename metadata column", + sql(s"SELECT renamed_metadata FROM (SELECT _metadata AS renamed_metadata FROM $testTable)" + )) { + case Project(projectList, lr: LogicalRelation) => + assert(projectList.map(_.name) === Seq("renamed_metadata"), "Project list didn't match") + assert(lr.output.map(_.name) === Seq("id", "_metadata")) + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdSuite.scala index 12a7ebb6861..dea69f9a3ba 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdSuite.scala @@ -16,7 +16,9 @@ package org.apache.spark.sql.delta.rowid -import org.apache.spark.sql.delta.{DeltaConfigs, DeltaIllegalStateException, DeltaLog, RowId, RowTrackingFeature, Serializable, SnapshotIsolation, WriteSerializable} +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.delta.{DeltaConfigs, DeltaIllegalStateException, DeltaLog, DeltaOperations, MaterializedRowId, RowId, RowTrackingFeature, Serializable, SnapshotIsolation} import org.apache.spark.sql.delta.DeltaOperations.ManualUpdate import org.apache.spark.sql.delta.DeltaTestUtils.BOOLEAN_DOMAIN import org.apache.spark.sql.delta.RowId.RowTrackingMetadataDomain @@ -25,15 +27,22 @@ import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils.TABLE_FEATUR import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.test.DeltaTestImplicits._ import org.apache.spark.sql.delta.util.FileNames +import org.apache.parquet.column.Encoding +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField, StructType} class RowIdSuite extends QueryTest with SharedSparkSession + with ParquetTest with RowIdTestUtils { test("Enabling row IDs on existing table does not set row IDs as readable") { withRowTrackingEnabled(enabled = false) { @@ -190,7 +199,6 @@ class RowIdSuite extends QueryTest RowId.extractHighWatermark(log.update()) assert(highWatermarkWithNewData == highWatermarkWithNewDataAfterRestore) assertRowIdsDoNotOverlap(log) - } } } @@ -539,4 +547,362 @@ class RowIdSuite extends QueryTest } } } + + test("row ids cannot be read when they are disabled") { + 
withRowTrackingEnabled(enabled = false) { + withTempDir { dir => + spark.range(start = 0, end = 1000, step = 1, numPartitions = 10) + .write.format("delta").save(dir.getAbsolutePath) + + withAllParquetReaders { + val err = intercept[AnalysisException] { + spark.read.format("delta").load(dir.toString).select("_metadata.row_id").collect() + } + assert(err.getMessage.contains("No such struct field")) + } + } + } + } + + + // Although readers don't have any row-id specific implementation, we still check that row IDs + // can be read back, which verifies that we switch to a reader that supports row IDs if the + // selected reader isn't able to read them. + test("row ids can be read back") { + withRowTrackingEnabled(enabled = true) { + withAllParquetReaders { + assertRowIdsCanBeReadWithRowGroupSkipping(start = 50) + // Column mapping + withSQLConf(DeltaConfigs.COLUMN_MAPPING_MODE.defaultTablePropertyKey -> "name") { + assertRowIdsCanBeRead(start = 100, numRows = 100) + } + } + } + } + + test("Can read both row id and row index") { + withRowTrackingEnabled(enabled = true) { + withAllParquetReaders { + withTempDir { dir => + val start = 10 + val recordsPerFile = 5 + spark.range(start = start, end = 20, step = 1, numPartitions = 2) + .toDF("value") + .write + .format("delta") + .save(dir.getAbsolutePath) + val df1 = spark.read.format("delta").load(dir.getAbsolutePath) + .select(RowId.QUALIFIED_COLUMN_NAME, "value", "_metadata.row_index") + checkAnswer(df1, (0 until 10).map(i => Row(i, start + i, i % recordsPerFile))) + } + } + } + } + + test("Row ID metadata field has the expected type") { + withRowTrackingEnabled(enabled = true) { + withTempDir { tempDir => + spark.range(start = 0, end = 20).toDF("id") + .write.format("delta").save(tempDir.getAbsolutePath) + + val df = spark.read.format("delta").load(tempDir.getAbsolutePath) + .select(RowId.QUALIFIED_COLUMN_NAME) + + val expectedRowIdMetadata = new MetadataBuilder() + .putBoolean("__row_id_metadata_col", value = true) + .build() + + val expectedRowIdField = StructField( + RowId.ROW_ID, + LongType, + nullable = false, + metadata = expectedRowIdMetadata) + + Seq(df.schema, df.queryExecution.analyzed.schema, df.queryExecution.optimizedPlan.schema) + .foreach { schema => + assert(schema === new StructType().add(expectedRowIdField)) + } + } + } + } + + test("Row IDs can be read in subquery") { + withRowTrackingEnabled(enabled = true) { + withTempDir { tempDir => + // Generate 2 files with base Row ID 0 and 20 resp. + spark.range(start = 0, end = 20).toDF("id").repartition(1) + .write.format("delta").save(tempDir.getAbsolutePath) + spark.range(start = 20, end = 30).toDF("id").repartition(1) + .write.format("delta").mode("append").save(tempDir.getAbsolutePath) + + val rows = sql( + s""" + |SELECT * FROM delta.`${tempDir.getAbsolutePath}` + |WHERE id IN ( + | SELECT ${RowId.QUALIFIED_COLUMN_NAME} + | FROM delta.`${tempDir.getAbsolutePath}`) + """.stripMargin) + checkAnswer(rows, (0 until 30).map(Row(_))) + } + } + } + + test("Filter by Row IDs") { + withRowTrackingEnabled(enabled = true) { + withTempDir { tempDir => + spark.range(start = 100, end = 110).toDF("id") + .write.format("delta").save(tempDir.getAbsolutePath) + + val rows = spark.read.format("delta") + .load(tempDir.getAbsolutePath).filter("_metadata.row_id % 2 = 0") + + checkAnswer(rows, (100.until(end = 110, step = 2)).map(Row(_))) + } + } + } + + test("Filter by Row IDs in subquery") { + withRowTrackingEnabled(enabled = true) { + withTempDir { tempDir => + // Generate 2 files with base Row ID 0 and 20 resp.
+ spark.range(start = 0, end = 20).toDF("id").repartition(1) + .write.format("delta").save(tempDir.getAbsolutePath) + spark.range(start = 20, end = 30).toDF("id").repartition(1) + .write.format("delta").mode("append").save(tempDir.getAbsolutePath) + + val rows = sql( + s""" + |SELECT * FROM delta.`${tempDir.getAbsolutePath}` + |WHERE id IN ( + | SELECT id + | FROM delta.`${tempDir.getAbsolutePath}` + | WHERE ${RowId.QUALIFIED_COLUMN_NAME} % 5 = 0) + """.stripMargin) + checkAnswer(rows, Seq(Row(0), Row(5), Row(10), Row(15), Row(20), Row(25))) + } + } + } + + test("Row IDs cannot be read if the table property is not enabled") { + withRowTrackingEnabled(enabled = true) { + withAllParquetReaders { + withTable("target") { + spark.range(10).repartition(1).write.format("delta").saveAsTable("target") + var df = spark.read.table("target") + val expected = (0 until 10).map(i => Row(i, i)) + // Check that row IDs can be read while table property is enabled + checkAnswer(df.select("id", "_metadata.row_id"), expected) + + sql( + s""" + |ALTER TABLE target + |SET TBLPROPERTIES ('${DeltaConfigs.ROW_TRACKING_ENABLED.key}' = false) + |""".stripMargin) + + df = spark.read.format("delta").table("target") + val err = intercept[AnalysisException] { + checkAnswer(df.select("id", "_metadata.row_id"), expected) + } + assert(err.getMessage.contains("No such struct field")) + // can still read other columns when table property disabled + checkAnswer(df.select("id"), (0 until 10).map(Row(_))) + } + } + } + } + + test("No row-group skipping on _metadata.row_id") { + withAllParquetReaders { + withRowTrackingEnabled(enabled = true) { + withTempPath { path => + val numRows = ParquetProperties.DEFAULT_MINIMUM_RECORD_COUNT_FOR_CHECK + val materializedColName = "materialized_rowid_col" + + val df = spark.range(start = 0, end = numRows, step = 1, numPartitions = 1) + .toDF("value") + .withColumn(materializedColName, + when(col("value") < (numRows / 2), col("value")) + .otherwise(lit(null))) + writeParquetWithMinimalRowGroupSize(df, path.toString) + + sql(s"CONVERT TO DELTA parquet.`$path`") + + setRowIdMaterializedColumnName( + DeltaLog.forTable(spark, path), colName = materializedColName) + + checkFileLayout( + path, + numFiles = 1, + numRowGroupsPerFile = 1, + rowCountPerRowGroup = numRows) + + // Filter by row IDs that are not part of the materialized column. If we don't take fresh + // row IDs into account, the row group will be skipped and the test will fail. + val dfWithSkippingOnRowId = spark.read.format("delta").load(path.toString).select("value") + .where(col(RowId.QUALIFIED_COLUMN_NAME) >= (numRows / 2)) + checkAnswer(dfWithSkippingOnRowId, ((numRows / 2) until numRows).map(Row(_))) + checkScanMetrics( + dfWithSkippingOnRowId.queryExecution.executedPlan, + expectedNumOfRows = numRows) + } + } + } + } + + test("No dictionary filtering on _metadata.row_id") { + withAllParquetReaders { + withRowTrackingEnabled(enabled = true) { + withTempPath { path => + val numRows = ParquetProperties.DEFAULT_MINIMUM_RECORD_COUNT_FOR_CHECK + val materializedColName = "materialized_rowid_col" + + val df = spark.range(start = 0, end = numRows, step = 1, numPartitions = 1) + .toDF("value") + .withColumn(materializedColName, + // This will cause dictionary encoding to be used, as the column has few unique + // values. Normally this shouldn't happen with row IDs, but we want to ensure that + // we can still read row IDs correctly if dictionary encoding is used. 
+ when(col("value") > 0, lit(1L)) + .otherwise(lit(null))) + writeParquetWithMinimalRowGroupSize(df, path.toString) + + sql(s"CONVERT TO DELTA parquet.`$path`") + + setRowIdMaterializedColumnName( + DeltaLog.forTable(spark, path), colName = materializedColName) + + checkFileLayout( + path, + numFiles = 1, + numRowGroupsPerFile = 1, + rowCountPerRowGroup = numRows) + + // We can't check directly whether dictionary filtering will take place, but we can ensure + // that the row ID column is dictionary encoded, which should mean that the + // optimization is applied. + readRowGroupsPerFile(path).flatten.foreach { block => + val rowIdColChunk = block.getColumns.asScala.find( + _.getPath.asScala.exists(_ == materializedColName)).get + assert(rowIdColChunk.getEncodings.contains(Encoding.PLAIN_DICTIONARY)) + } + + // Filter by row IDs that are not part of the materialized column. If we don't take fresh + // row IDs into account, the row group will be skipped and the test will fail. + val dfWithSkippingOnRowId = spark.read.format("delta").load(path.toString).select("value") + .where(col(RowId.QUALIFIED_COLUMN_NAME).equalTo(0)) + checkAnswer(dfWithSkippingOnRowId, Row(0)) + checkScanMetrics( + dfWithSkippingOnRowId.queryExecution.executedPlan, + expectedNumOfRows = numRows) + } + } + } + } + + test("Reading row IDs when file is split and splits are recombined") { + withSQLConf( + DeltaConfigs.ROW_TRACKING_ENABLED.defaultTablePropertyKey -> "true", + // 10 byte partition sizes + SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10B") { + withTempDir { dir => + spark.range(end = 10).repartition(1) + // Add some more random columns, leads to multiple splits being recombined into a single + // partition + .selectExpr("id", "id as id2", "id as id3", "id as id4") + .write.format("delta").save(dir.toString) + val log = DeltaLog.forTable(spark, dir) + // Make sure we would create at least two splits of a single file + val necessarySplitSizeBytes = 20 + assert(log.update().allFiles.collect().forall(_.size > necessarySplitSizeBytes)) + checkAnswer( + spark.read.format("delta").load(dir.toString).select("id", RowId.QUALIFIED_COLUMN_NAME), + (0 until 10).map(i => Row(i, i))) + } + } + } + + protected def assertRowIdsCanBeRead(start: Int, numRows: Int): Unit = { + withTempDir { dir => + spark.range(start, end = start + numRows, step = 1, numPartitions = 3) + .toDF("value") + .write + .format("delta") + .save(dir.getAbsolutePath) + + val df1 = spark.read.format("delta").load(dir.getAbsolutePath) + .select(RowId.QUALIFIED_COLUMN_NAME, "value") + checkAnswer(df1, (0L until numRows).map(i => Row(i, start + i))) + + val df2 = spark.read.format("delta").load(dir.getAbsolutePath) + .select("value", RowId.QUALIFIED_COLUMN_NAME) + checkAnswer(df2, (0L until numRows).map(i => Row(start + i, i))) + } + } + + protected def writeParquetWithMinimalRowGroupSize(df: DataFrame, path: String): Unit = { + df.write + .format("parquet") + // The minimum row count in a row group is + // `ParquetProperties.DEFAULT_MINIMUM_RECORD_COUNT_FOR_CHECK`, if we specify a + // block size that can't accommodate the minimum row count, we'll write exactly + // the minimum row count per row group. 
+ .option(ParquetOutputFormat.BLOCK_SIZE, 0) + .save(path) + } + + protected def setRowIdMaterializedColumnName(log: DeltaLog, colName: String): Unit = { + val metadata = log.update().metadata + val configWithUpdatedRowIdColName = metadata.configuration + ( + MaterializedRowId.MATERIALIZED_COLUMN_NAME_PROP -> colName) + // We need to remove the column from the schema as we are not allowed to have the + // materialized row ID column be part of the schema. + val schemaFieldsWithoutRowIdCol = metadata.schema.filterNot(_.name == colName) + val updatedMetadata = metadata.copy( + configuration = configWithUpdatedRowIdColName, + schemaString = metadata.schema.copy(fields = schemaFieldsWithoutRowIdCol.toArray).json) + log.startTransaction().commit(Seq(updatedMetadata), DeltaOperations.ManualUpdate) + } + + protected def checkScanMetrics(plan: SparkPlan, expectedNumOfRows: Long): Unit = { + var numOutputRows = 0L + plan.foreach { + case f: FileSourceScanExec => + numOutputRows += f.metrics("numOutputRows").value + case _ => // Not a scan node, do nothing. + } + assert(expectedNumOfRows === numOutputRows) + } + + private def assertRowIdsCanBeReadWithRowGroupSkipping(start: Int): Unit = { + val rowGroupRowCount = ParquetProperties.DEFAULT_MINIMUM_RECORD_COUNT_FOR_CHECK + // write at least two row groups + val numRows = rowGroupRowCount * 2 + withTempPath { path => + val df = spark.range(start, end = start + numRows, step = 1, numPartitions = 1).toDF("value") + writeParquetWithMinimalRowGroupSize(df, path.toString) + sql(s"CONVERT TO DELTA parquet.`$path`") + + import testImplicits._ + checkFileLayout( + path, + numFiles = 1, + numRowGroupsPerFile = 2, + rowCountPerRowGroup = rowGroupRowCount) + + val rowGroups = readRowGroupsPerFile(path).head + val minValueSecondRowGroup = rowGroups(1).getColumns.get(0).getStatistics.genericGetMin() + + val df1 = spark.read.format("delta").load(path.getAbsolutePath) + .filter($"value" >= minValueSecondRowGroup) + .select(RowId.QUALIFIED_COLUMN_NAME, "value") + checkAnswer(df1, (rowGroupRowCount until numRows).map(i => Row(i, start + i))) + checkScanMetrics(df1.queryExecution.executedPlan, expectedNumOfRows = rowGroupRowCount) + + val df2 = spark.read.format("delta").load(path.getAbsolutePath) + .filter($"value" >= minValueSecondRowGroup) + .select("value", RowId.QUALIFIED_COLUMN_NAME) + checkAnswer(df2, (rowGroupRowCount until numRows).map(i => Row(start + i, i))) + checkScanMetrics(df2.queryExecution.executedPlan, expectedNumOfRows = rowGroupRowCount) + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala index 6d15b1b7f76..c1cc0565a42 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdTestUtils.scala @@ -16,14 +16,21 @@ package org.apache.spark.sql.delta.rowid +import java.io.File + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.delta.{DeltaLog, MaterializedRowId, RowId} import org.apache.spark.sql.delta.actions.AddFile import org.apache.spark.sql.delta.rowtracking.RowTrackingTestUtils import org.apache.spark.sql.delta.test.DeltaSQLCommandTest +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.metadata.BlockMetaData import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -trait RowIdTestUtils extends RowTrackingTestUtils with 
DeltaSQLCommandTest { +trait RowIdTestUtils extends RowTrackingTestUtils with DeltaSQLCommandTest with ParquetTest { val QUALIFIED_BASE_ROW_ID_COLUMN_NAME = s"${FileFormat.METADATA_NAME}.${RowId.BASE_ROW_ID}" protected def getRowIdRangeInclusive(f: AddFile): (Long, Long) = { @@ -93,4 +100,28 @@ trait RowIdTestUtils extends RowTrackingTestUtils with DeltaSQLCommandTest { def extractMaterializedRowIdColumnName(log: DeltaLog): Option[String] = { log.update().metadata.configuration.get(MaterializedRowId.MATERIALIZED_COLUMN_NAME_PROP) } + + protected def readRowGroupsPerFile(dir: File): Seq[Seq[BlockMetaData]] = { + assert(dir.isDirectory) + readAllFootersWithoutSummaryFiles( + // scalastyle:off deltahadoopconfiguration + new Path(dir.getAbsolutePath), spark.sessionState.newHadoopConf()) + // scalastyle:on deltahadoopconfiguration + .map(_.getParquetMetadata.getBlocks.asScala.toSeq) + } + + protected def checkFileLayout( + dir: File, + numFiles: Int, + numRowGroupsPerFile: Int, + rowCountPerRowGroup: Int): Unit = { + val rowGroupsPerFile = readRowGroupsPerFile(dir) + assert(numFiles === rowGroupsPerFile.size) + for (rowGroups <- rowGroupsPerFile) { + assert(numRowGroupsPerFile === rowGroups.size) + for (rowGroup <- rowGroups) { + assert(rowCountPerRowGroup === rowGroup.getRowCount) + } + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala index 28984d1cca0..3b91e18c1e4 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala @@ -23,6 +23,7 @@ import java.util.regex.Pattern import scala.annotation.tailrec import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaLog, DeltaTestUtils} +import org.apache.spark.sql.delta.RowId import org.apache.spark.sql.delta.commands.cdc.CDCReader import org.apache.spark.sql.delta.schema.SchemaMergingUtils._ import org.apache.spark.sql.delta.sources.DeltaSourceUtils.GENERATION_EXPRESSION_METADATA_KEY @@ -2015,6 +2016,28 @@ class SchemaUtilsSuite extends QueryTest assert(normalized.schema === tableSchema) } + test("normalize column names - can normalize row id column") { + withTable("src") { + spark.range(3).toDF("id").write + .format("delta") + .mode("overwrite") + .option("delta.enableRowTracking", "true") + .saveAsTable("src") + + val df = spark.read.format("delta").table("src") + .select( + col("*"), + col("_metadata.row_id").as("row_id") + ) + .withMetadata("row_id", RowId.RowIdMetadataStructField.metadata("name")) + + val tableSchema = new StructType().add("id", LongType) + val normalized = + normalizeColumnNames(deltaLog = null, tableSchema, df) + assert(normalized.schema.fieldNames === Seq("id", "row_id")) + } + } + test("normalize column names - can normalize CDC type column") { val df = Seq((1, 2, 3, 4)).toDF("Abc", "def", "gHi", CDCReader.CDC_TYPE_COLUMN_NAME) val tableSchema = new StructType()