Skip to content

Commit

Permalink
[Spark] Add read support for baseRowID (delta-io#2779)
Browse files Browse the repository at this point in the history
<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
Adding the `base_row_id` field to the `_metadata` column for Delta
tables. This field contains the value in the `baseRowId` field of the
`AddFile` action for the file, allowing us to read the `base_row_id`
from the file metadata after it is stored.
<!--
- Describe what this PR changes.
- Describe why we need the change.
 
If this PR resolves an issue be sure to include "Resolves #XXX" to
correctly link and close the issue upon merge.
-->

## How was this patch tested?
Added UTs.
<!--
If tests were added, say they were added here. Please make sure to test
the changes thoroughly including negative and positive cases if
possible.
If the changes were tested in any way other than unit tests, please
clarify how you tested step by step (ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future).
If the changes were not tested, please explain why.
-->

## Does this PR introduce _any_ user-facing changes?
No.
<!--
If yes, please clarify the previous behavior and the change this PR
proposes - provide the console output, description and/or an example to
show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change
compared to the released Delta Lake versions or within the unreleased
branches such as master.
If no, write 'No'.
-->
  • Loading branch information
longvu-db authored Mar 25, 2024
1 parent 4b043b1 commit 283ac02
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ case class DeltaParquetFileFormat(
// Parquet reader in Spark has a bug where a file containing 2b+ rows in a single rowgroup
// causes it to run out of the `Integer` range (TODO: Create a SPARK issue)
// For Delta Parquet readers don't expose the row_index field as a metadata field.
super.metadataSchemaFields.filter(field => field != ParquetFileFormat.ROW_INDEX_FIELD)
super.metadataSchemaFields.filter(field => field != ParquetFileFormat.ROW_INDEX_FIELD) ++
RowId.createBaseRowIdField(protocol, metadata)
}

override def prepareWrite(
Expand All @@ -224,6 +225,17 @@ case class DeltaParquetFileFormat(
factory
}

override def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] = {
val extractBaseRowId: PartitionedFile => Any = { file =>
file.otherConstantMetadataColumnValues.getOrElse(RowId.BASE_ROW_ID, {
throw new IllegalStateException(
s"Missing ${RowId.BASE_ROW_ID} value for file '${file.filePath}'")
})
}
super.fileConstantMetadataExtractors
.updated(RowId.BASE_ROW_ID, extractBaseRowId)
}

def copyWithDVInfo(
tablePath: String,
broadcastDvMap: Broadcast[Map[URI, DeletionVectorDescriptorWithFilterType]],
Expand Down
49 changes: 49 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/delta/RowId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.actions.{Action, AddFile, DomainMetadata, Metadata, Protocol}
import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils.propertyKey
import org.apache.spark.sql.util.ScalaExtensions._

import org.apache.spark.sql.catalyst.expressions.FileSourceConstantMetadataStructField
import org.apache.spark.sql.types
import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder, StructField}

/**
* Collection of helpers to handle Row IDs.
Expand Down Expand Up @@ -108,4 +113,48 @@ object RowId {
*/
private[delta] def extractHighWatermark(snapshot: Snapshot): Option[Long] =
RowTrackingMetadataDomain.fromSnapshot(snapshot).map(_.rowIdHighWaterMark)

/** Base Row ID column name */
val BASE_ROW_ID = "base_row_id"

/*
* A specialization of [[FileSourceConstantMetadataStructField]] used to represent base RowId
* columns.
*/
object BaseRowIdMetadataStructField {
private val BASE_ROW_ID_METADATA_COL_ATTR_KEY = s"__base_row_id_metadata_col"

def metadata: types.Metadata = new MetadataBuilder()
.withMetadata(FileSourceConstantMetadataStructField.metadata(BASE_ROW_ID))
.putBoolean(BASE_ROW_ID_METADATA_COL_ATTR_KEY, value = true)
.build()

def apply(): StructField =
StructField(
BASE_ROW_ID,
LongType,
nullable = false,
metadata = metadata)

def unapply(field: StructField): Option[StructField] =
Some(field).filter(isBaseRowIdColumn)

/** Return true if the column is a base Row ID column. */
def isBaseRowIdColumn(structField: StructField): Boolean =
isValid(structField.dataType, structField.metadata)

def isValid(dataType: DataType, metadata: types.Metadata): Boolean = {
FileSourceConstantMetadataStructField.isValid(dataType, metadata) &&
metadata.contains(BASE_ROW_ID_METADATA_COL_ATTR_KEY) &&
metadata.getBoolean(BASE_ROW_ID_METADATA_COL_ATTR_KEY)
}
}

/**
* The field readers can use to access the base row id column.
*/
def createBaseRowIdField(protocol: Protocol, metadata: Metadata): Option[StructField] =
Option.when(RowId.isEnabled(protocol, metadata)) {
BaseRowIdMetadataStructField()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ package org.apache.spark.sql.delta.files
import java.net.URI
import java.util.Objects

import scala.collection.mutable
import org.apache.spark.sql.delta.RowIndexFilterType
import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaErrors, DeltaLog, NoMapping, Snapshot, SnapshotDescriptor}
import org.apache.spark.sql.delta.RowId
import org.apache.spark.sql.delta.actions.{AddFile, Metadata, Protocol}
import org.apache.spark.sql.delta.implicits._
import org.apache.spark.sql.delta.schema.SchemaUtils
Expand Down Expand Up @@ -107,23 +109,29 @@ abstract class TahoeFileIndex(
matchingFiles(partitionFilters, dataFilters).groupBy(_.partitionValues)
}

/**
* Generates a FileStatusWithMetadata using data extracted from a given AddFile.
*/
def fileStatusWithMetadataFromAddFile(addFile: AddFile): FileStatusWithMetadata = {
val fs = new FileStatus(
/* length */ addFile.size,
/* isDir */ false,
/* blockReplication */ 0,
/* blockSize */ 1,
/* modificationTime */ addFile.modificationTime,
/* path */ absolutePath(addFile.path))
val metadata = mutable.Map.empty[String, Any]
addFile.baseRowId.foreach(baseRowId => metadata.put(RowId.BASE_ROW_ID, baseRowId))

FileStatusWithMetadata(fs, metadata.toMap)
}

def makePartitionDirectories(
partitionValuesToFiles: Seq[(InternalRow, Seq[AddFile])]): Seq[PartitionDirectory] = {
val timeZone = spark.sessionState.conf.sessionLocalTimeZone
partitionValuesToFiles.map {
case (partitionValues, files) =>

val fileStatuses = files.map { f =>
new FileStatus(
/* length */ f.size,
/* isDir */ false,
/* blockReplication */ 0,
/* blockSize */ 1,
/* modificationTime */ f.modificationTime,
absolutePath(f.path))
}.toArray

val fileStatuses = files.map(f => fileStatusWithMetadataFromAddFile(f)).toArray
PartitionDirectory(partitionValues, fileStatuses)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class TahoeRemoveFileIndex(
modificationTime = 0,
dataChange = r.dataChange,
tags = r.tags,
deletionVector = r.deletionVector
deletionVector = r.deletionVector,
baseRowId = r.baseRowId
)
}
}
Expand Down
159 changes: 159 additions & 0 deletions spark/src/test/scala/org/apache/spark/sql/delta/rowid/RowIdSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField, StructType}

class RowIdSuite extends QueryTest
with SharedSparkSession
Expand Down Expand Up @@ -380,4 +381,162 @@ class RowIdSuite extends QueryTest
}
}
}

test("Base Row ID metadata field has the expected type") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
spark.range(start = 0, end = 20).toDF("id")
.write.format("delta").save(tempDir.getAbsolutePath)

val df = spark.read.format("delta").load(tempDir.getAbsolutePath)
.select(QUALIFIED_BASE_ROW_ID_COLUMN_NAME)

val expectedBaseRowIdMetadata = new MetadataBuilder()
.putBoolean("__base_row_id_metadata_col", value = true)
.build()

val expectedBaseRowIdField = StructField(
RowId.BASE_ROW_ID,
LongType,
nullable = false,
metadata = expectedBaseRowIdMetadata)

Seq(df.schema, df.queryExecution.analyzed.schema, df.queryExecution.optimizedPlan.schema)
.foreach { schema =>
assert(schema === new StructType().add(expectedBaseRowIdField))
}
}
}
}

test("Base Row IDs can be read with conflicting metadata column name") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
// Generate 2 files with base Row ID 0 and 20 resp.
spark.range(start = 0, end = 20).toDF("_metadata").repartition(1)
.write.format("delta").save(tempDir.getAbsolutePath)
spark.range(start = 20, end = 30).toDF("_metadata").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

val df = spark.read.format("delta").load(tempDir.getAbsolutePath).select("_metadata")

val dfWithConflict = df
.select(
col("_metadata"),
df.metadataColumn("_metadata")
.getField({RowId.BASE_ROW_ID})
.as("real_base_row_id"))
.where("real_base_row_id % 5 = 0")

checkAnswer(dfWithConflict,
(0 until 20).map(Row(_, 0)) ++
(20 until 30).map(Row(_, 20)))
}
}
}

test("Base Row IDs can be read through the Scala syntax") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
// Generate 2 files with base Row ID 0 and 20 resp.
spark.range(start = 0, end = 20).toDF("id").repartition(1)
.write.format("delta").save(tempDir.getAbsolutePath)
spark.range(start = 20, end = 30).toDF("id").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

val df = spark.read.format("delta").load(tempDir.getAbsolutePath)
.select("id", QUALIFIED_BASE_ROW_ID_COLUMN_NAME)

checkAnswer(df,
(0 until 20).map(Row(_, 0)) ++
(20 until 30).map(Row(_, 20)))
}
}
}

test("Base Row IDs can be read through the SQL syntax") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
// Generate 2 files with base Row ID 0 and 20 resp.
spark.range(start = 0, end = 20).toDF("id").repartition(1)
.write.format("delta").save(tempDir.getAbsolutePath)
spark.range(start = 20, end = 30).toDF("id").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

val rows = sql(
s"""
|SELECT id, $QUALIFIED_BASE_ROW_ID_COLUMN_NAME FROM delta.`${tempDir.getAbsolutePath}`
""".stripMargin
)

checkAnswer(rows,
(0 until 20).map(Row(_, 0)) ++
(20 until 30).map(Row(_, 20)))
}
}
}

test("Filter by base Row IDs") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
// Generate 3 files with base Row ID 0, 10 and 20 resp.
spark.range(start = 0, end = 10).toDF("id").repartition(1)
.write.format("delta").save(tempDir.getAbsolutePath)
spark.range(start = 10, end = 20).toDF("id").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)
spark.range(start = 20, end = 30).toDF("id").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

val df = spark.read.format("delta").load(tempDir.getAbsolutePath)
.where(col(QUALIFIED_BASE_ROW_ID_COLUMN_NAME) === 10)

checkAnswer(df, (10 until 20).map(Row(_)))
}
}
}

test("Base Row IDs can be read in subquery") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
// Generate 2 files with base Row ID 0 and 20 resp.
spark.range(start = 0, end = 20).toDF("id").repartition(1)
.write.format("delta").save(tempDir.getAbsolutePath)
spark.range(start = 20, end = 30).toDF("id").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

val rows = sql(
s"""
|SELECT * FROM delta.`${tempDir.getAbsolutePath}`
|WHERE id IN (
| SELECT $QUALIFIED_BASE_ROW_ID_COLUMN_NAME
| FROM delta.`${tempDir.getAbsolutePath}`)
""".stripMargin)

checkAnswer(rows, Seq(Row(0), Row(20)))
}
}
}

test("Filter by base Row IDs in subquery") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
// Generate 2 files with base Row ID 0 and 20 resp.
spark.range(start = 0, end = 20).toDF("id").repartition(1)
.write.format("delta").save(tempDir.getAbsolutePath)
spark.range(start = 20, end = 30).toDF("id").repartition(1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

val rows = sql(
s"""
|SELECT * FROM delta.`${tempDir.getAbsolutePath}`
|WHERE id IN (
| SELECT id
| FROM delta.`${tempDir.getAbsolutePath}`
| WHERE $QUALIFIED_BASE_ROW_ID_COLUMN_NAME = 20)
""".stripMargin)

checkAnswer(rows, (20 until 30).map(Row(_)))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ import org.apache.spark.sql.delta.actions.AddFile
import org.apache.spark.sql.delta.rowtracking.RowTrackingTestUtils
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest

import org.apache.spark.sql.execution.datasources.FileFormat

trait RowIdTestUtils extends RowTrackingTestUtils with DeltaSQLCommandTest {
val QUALIFIED_BASE_ROW_ID_COLUMN_NAME = s"${FileFormat.METADATA_NAME}.${RowId.BASE_ROW_ID}"

protected def getRowIdRangeInclusive(f: AddFile): (Long, Long) = {
val min = f.baseRowId.get
val max = min + f.numPhysicalRecords.get - 1L
Expand Down

0 comments on commit 283ac02

Please sign in to comment.