1385 - Collect statistics by default in ConvertToDelta & Update SQL API
## Description

Resolves #1385

This PR adds the ability to collect per-file column statistics during the ConvertToDelta operation. Statistics are now collected by default; the previous behaviour (no statistics on convert) is available through the new `NO STATISTICS` clause in SQL, and collection remains gated by the `DeltaSQLConf.DELTA_COLLECT_STATS` configuration.
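
For example, with a SparkSession that has the Delta SQL extensions enabled (the paths below are hypothetical; clause ordering follows the updated grammar, where `NO STATISTICS` precedes `PARTITIONED BY`):

```scala
// Default: per-file column statistics are collected during conversion,
// gated by the DeltaSQLConf.DELTA_COLLECT_STATS configuration.
spark.sql("CONVERT TO DELTA parquet.`/tmp/events` PARTITIONED BY (date STRING)")

// Opt out of statistics collection with the new clause.
spark.sql("CONVERT TO DELTA parquet.`/tmp/events` NO STATISTICS PARTITIONED BY (date STRING)")
```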

Tested with the existing unit tests, plus a couple of new unit tests added to cover this change.

## Does this PR introduce _any_ user-facing changes?

Yes. Previously the ConvertToDelta command could not collect statistics. `ConvertToDeltaCommand` now takes an additional boolean parameter, `collectStats`, and the SQL syntax accepts an optional `NO STATISTICS` clause to opt out.
When statistics collection is enabled (and `DeltaSQLConf.DELTA_COLLECT_STATS` is true), Delta computes column statistics for each converted file in batches and records them on the `AddFile` actions written by the convert commit.
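
A minimal sketch of the programmatic side, assuming a running `SparkSession` with Delta configured (paths are illustrative; the stats inspection mirrors the new test suites rather than a public API):

```scala
import io.delta.tables.DeltaTable
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.functions.{col, from_json}

// The Scala/Python DeltaTable API keeps its existing signature; it now passes
// collectStats = true to ConvertToDeltaCommand under the hood.
DeltaTable.convertToDelta(spark, "parquet.`/tmp/events`", "date STRING")

// Inspect the per-file statistics recorded on the AddFile actions of the
// convert commit, the same way the new tests do.
val deltaLog = DeltaLog.forTable(spark, "/tmp/events")
val statsDf = deltaLog.unsafeVolatileSnapshot.allFiles
  .select(from_json(col("stats"), deltaLog.unsafeVolatileSnapshot.statsSchema).as("stats"))
  .select("stats.*")
statsDf.select("numRecords").show()
```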

Closes #1401

Co-authored-by: Amir Mor <amir.mor@appsflyer.com>
Signed-off-by: Scott Sandre <scott.sandre@databricks.com>
Signed-off-by: Venki Korukanti <venki.korukanti@databricks.com>
GitOrigin-RevId: 095b503053d2749ae1fd50e49348c3157b19366d
2 people authored and vkorukanti committed Nov 30, 2022
1 parent 406e225 commit a5fcec4
Showing 11 changed files with 252 additions and 25 deletions.
6 changes: 4 additions & 2 deletions core/src/main/antlr4/io/delta/sql/parser/DeltaSqlBase.g4
@@ -79,7 +79,7 @@ statement
| (DESC | DESCRIBE) HISTORY (path=STRING | table=qualifiedName)
(LIMIT limit=INTEGER_VALUE)? #describeDeltaHistory
| CONVERT TO DELTA table=qualifiedName
(PARTITIONED BY '(' colTypeList ')')? #convert
(NO STATISTICS)? (PARTITIONED BY '(' colTypeList ')')? #convert
| RESTORE TABLE? table=qualifiedName TO?
clause=temporalClause #restore
| ALTER TABLE table=qualifiedName ADD CONSTRAINT name=identifier
@@ -166,7 +166,7 @@ nonReserved
| GENERATE | FOR | TABLE | CHECK | EXISTS | OPTIMIZE
| RESTORE | AS | OF
| ZORDER | LEFT_PAREN | RIGHT_PAREN
| SHOW | COLUMNS | IN | FROM
| SHOW | COLUMNS | IN | FROM | NO | STATISTICS
;

// Define how the keywords above should appear in a user's SQL statement.
@@ -197,6 +197,7 @@ IN: 'IN';
LEFT_PAREN: '(';
LIMIT: 'LIMIT';
MINUS: '-';
NO: 'NO';
NOT: 'NOT' | '!';
NULL: 'NULL';
OF: 'OF';
@@ -216,6 +217,7 @@ VACUUM: 'VACUUM';
VERSION: 'VERSION';
WHERE: 'WHERE';
ZORDER: 'ZORDER';
STATISTICS: 'STATISTICS';

// Multi-character operator tokens need to be defined even though we don't explicitly reference
// them so that they can be recognized as single tokens when parsing. If we split them up and
@@ -230,7 +230,7 @@ class DeltaSqlAstBuilder extends DeltaSqlBaseBaseVisitor[AnyRef] {
ConvertToDeltaCommand(
visitTableIdentifier(ctx.table),
Option(ctx.colTypeList).map(colTypeList => StructType(visitColTypeList(colTypeList))),
None)
ctx.STATISTICS() == null, None)
}

override def visitRestore(ctx: RestoreContext): LogicalPlan = withOrigin(ctx) {
@@ -29,7 +29,8 @@ trait DeltaConvertBase {
tableIdentifier: TableIdentifier,
partitionSchema: Option[StructType],
deltaPath: Option[String]): DeltaTable = {
val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath)
val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, collectStats = true,
deltaPath)
cvt.run(spark)
if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) {
DeltaTable.forName(spark, tableIdentifier.toString)
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.actions.{AddFile, Metadata}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.VacuumCommand.{generateCandidateFileMap, getTouchedFile}
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.schema.SchemaMergingUtils
import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf}
@@ -66,13 +67,17 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
*
* @param tableIdentifier the target parquet table.
* @param partitionSchema the partition schema of the table, required when table is partitioned.
* @param collectStats Should collect column stats per file on convert.
* @param deltaPath if provided, the delta log will be written to this location.
*/
abstract class ConvertToDeltaCommandBase(
tableIdentifier: TableIdentifier,
partitionSchema: Option[StructType],
collectStats: Boolean,
deltaPath: Option[String]) extends LeafRunnableCommand with DeltaCommand {

protected lazy val statsEnabled: Boolean = conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS)

protected lazy val icebergEnabled: Boolean =
conf.getConf(DeltaSQLConf.DELTA_CONVERT_ICEBERG_ENABLED)

@@ -281,12 +286,26 @@ abstract class ConvertToDeltaCommandBase(
partitionSchema: StructType,
txn: OptimisticTransaction,
fs: FileSystem): Iterator[AddFile] = {
val initialSnapshot = new InitialSnapshot(txn.deltaLog.dataPath, txn.deltaLog, txn.metadata)
val shouldCollectStats = collectStats && statsEnabled
val statsBatchSize = conf.getConf(DeltaSQLConf.DELTA_IMPORT_BATCH_SIZE_STATS_COLLECTION)
var numFiles = 0L
manifest.getFiles.grouped(statsBatchSize).flatMap { batch =>
val adds = batch.map(
ConvertToDeltaCommand.createAddFile(
_, txn.deltaLog.dataPath, fs, conf, Some(partitionSchema), deltaPath.isDefined))
adds.toIterator
if (shouldCollectStats) {
logInfo(s"Collecting stats for a batch of ${batch.size} files; " +
s"finished $numFiles so far")
numFiles += statsBatchSize
ConvertToDeltaCommand.computeStats(txn.deltaLog, initialSnapshot, adds)
} else if (collectStats) {
logWarning(s"collectStats is set to true but ${DeltaSQLConf.DELTA_COLLECT_STATS.key}" +
s" is false. Skip statistics collection")
adds.toIterator
} else {
adds.toIterator
}
}
}

@@ -388,7 +407,7 @@ abstract class ConvertToDeltaCommandBase(
DeltaOperations.Convert(
numFilesConverted,
partitionSchema.map(_.fieldNames.toSeq).getOrElse(Nil),
collectStats = false,
collectStats = collectStats && statsEnabled,
convertProperties.catalogTable.map(t => t.identifier.toString),
sourceFormat = Some(sourceFormat))
}
@@ -412,8 +431,9 @@
case class ConvertToDeltaCommand(
tableIdentifier: TableIdentifier,
partitionSchema: Option[StructType],
collectStats: Boolean,
deltaPath: Option[String])
extends ConvertToDeltaCommandBase(tableIdentifier, partitionSchema, deltaPath)
extends ConvertToDeltaCommandBase(tableIdentifier, partitionSchema, collectStats, deltaPath)

/**
* An interface for the file to be included during conversion.
@@ -868,6 +888,21 @@ trait ConvertToDeltaCommandUtils extends DeltaLogging {
DeltaFileOperations.defaultHiddenFileFilter(fileName) && !fileName.contains("=")
}

def computeStats(
deltaLog: DeltaLog,
snapshot: Snapshot,
addFiles: Seq[AddFile]): Iterator[AddFile] = {
import org.apache.spark.sql.functions._
val filesWithStats = deltaLog.createDataFrame(snapshot, addFiles)
.groupBy(input_file_name()).agg(to_json(snapshot.statsCollector))

val pathToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, addFiles)
filesWithStats.collect().iterator.map { row =>
val addFile = getTouchedFile(deltaLog.dataPath, row.getString(0), pathToAddFileMap)
addFile.copy(stats = row.getString(1))
}
}

def getParquetTable(
spark: SparkSession,
targetDir: String,
@@ -91,7 +91,7 @@ trait DeltaCommand extends DeltaLogging {
* rewrite files such as delete, merge, update. We expect file names to be unique, because
* each file contains a UUID.
*/
protected def generateCandidateFileMap(
def generateCandidateFileMap(
basePath: Path,
candidateFiles: Seq[AddFile]): Map[String, AddFile] = {
val nameToAddFileMap = candidateFiles.map(add =>
@@ -154,7 +154,7 @@
* @param filePath The path to a file. Can be either absolute or relative
* @param nameToAddFileMap Map generated through `generateCandidateFileMap()`
*/
protected def getTouchedFile(
def getTouchedFile(
basePath: Path,
filePath: String,
nameToAddFileMap: Map[String, AddFile]): AddFile = {
@@ -16,20 +16,48 @@

package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest

import org.apache.spark.sql.functions.{col, from_json}

trait ConvertToDeltaSQLSuiteBase extends ConvertToDeltaSuiteBaseCommons
with DeltaSQLCommandTest {
override protected def convertToDelta(
identifier: String,
partitionSchema: Option[String] = None): Unit = {
partitionSchema: Option[String] = None, collectStats: Boolean = true): Unit = {
if (partitionSchema.isEmpty) {
sql(s"convert to delta $identifier")
sql(s"convert to delta $identifier ${collectStatisticsStringOption(collectStats)}")
} else {
val stringSchema = partitionSchema.get
sql(s"convert to delta $identifier partitioned by ($stringSchema) ")
sql(s"convert to delta $identifier ${collectStatisticsStringOption(collectStats)}" +
s" partitioned by ($stringSchema)")
}
}

// TODO: Move to ConvertToDeltaSuiteBaseCommons when DeltaTable API contains collectStats option
test("convert with collectStats set to false") {
withTempDir { dir =>
withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "true") {

val tempDir = dir.getCanonicalPath
writeFiles(tempDir, simpleDF)
convertToDelta(s"parquet.`$tempDir`", collectStats = false)
val deltaLog = DeltaLog.forTable(spark, tempDir)
val history = io.delta.tables.DeltaTable.forPath(tempDir).history()
checkAnswer(
spark.read.format("delta").load(tempDir),
simpleDF
)
assert(history.count == 1)
val statsDf = deltaLog.unsafeVolatileSnapshot.allFiles
.select(from_json(col("stats"), deltaLog.unsafeVolatileSnapshot.statsSchema)
.as("stats")).select("stats.*")
assert(statsDf.filter(col("numRecords").isNotNull).count == 0)
}
}
}

}

class ConvertToDeltaSQLSuite extends ConvertToDeltaSQLSuiteBase
@@ -21,7 +21,7 @@ import org.apache.spark.sql.types.StructType
class ConvertToDeltaScalaSuite extends ConvertToDeltaSuiteBase {
override protected def convertToDelta(
identifier: String,
partitionSchema: Option[String] = None): Unit = {
partitionSchema: Option[String] = None, collectStats: Boolean = true): Unit = {
if (partitionSchema.isDefined) {
io.delta.tables.DeltaTable.convertToDelta(
spark,
@@ -21,7 +21,6 @@ import java.io.{File, FileNotFoundException}
import org.apache.spark.sql.delta.files.TahoeLogFileIndex
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
import org.apache.spark.sql.delta.test.DeltaTestImplicits._
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
@@ -41,11 +40,15 @@ import org.apache.spark.util.Utils
*/
trait ConvertToDeltaTestUtils extends QueryTest { self: SQLTestUtils =>

protected def collectStatisticsStringOption(collectStats: Boolean): String = Option(collectStats)
.filterNot(identity).map(_ => "NO STATISTICS").getOrElse("")

protected def simpleDF = spark.range(100)
.withColumn("key1", col("id") % 2)
.withColumn("key2", col("id") % 3 cast "String")

protected def convertToDelta(identifier: String, partitionSchema: Option[String] = None): Unit
protected def convertToDelta(identifier: String, partitionSchema: Option[String] = None,
collectStats: Boolean = true): Unit

protected val blockNonDeltaMsg = "A transaction log for Delta was found at"
protected val parquetOnlyMsg = "CONVERT TO DELTA only supports parquet tables"
@@ -98,6 +101,47 @@ trait ConvertToDeltaSuiteBase extends ConvertToDeltaSuiteBaseCommons
}
}

test("convert with collectStats true") {
withTempDir { dir =>
val tempDir = dir.getCanonicalPath
writeFiles(tempDir, simpleDF)
convertToDelta(s"parquet.`$tempDir`", collectStats = true)
val deltaLog = DeltaLog.forTable(spark, tempDir)
val history = io.delta.tables.DeltaTable.forPath(tempDir).history()
checkAnswer(
spark.read.format("delta").load(tempDir),
simpleDF
)
assert(history.count == 1)
val statsDf = deltaLog.unsafeVolatileSnapshot.allFiles
.select(from_json($"stats", deltaLog.unsafeVolatileSnapshot.statsSchema)
.as("stats")).select("stats.*")
assert(statsDf.filter($"numRecords".isNull).count == 0)
assert(statsDf.agg(sum("numRecords")).as[Long].head() == simpleDF.count)
}
}

test("convert with collectStats true but config set to false -> Do not collect stats") {
withTempDir { dir =>
withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false") {
val tempDir = dir.getCanonicalPath
writeFiles(tempDir, simpleDF)
convertToDelta(s"parquet.`$tempDir`", collectStats = true)
val deltaLog = DeltaLog.forTable(spark, tempDir)
val history = io.delta.tables.DeltaTable.forPath(tempDir).history()
checkAnswer(
spark.read.format("delta").load(tempDir),
simpleDF
)
assert(history.count == 1)
val statsDf = deltaLog.unsafeVolatileSnapshot.allFiles
.select(from_json($"stats", deltaLog.unsafeVolatileSnapshot.statsSchema)
.as("stats")).select("stats.*")
assert(statsDf.filter($"numRecords".isNotNull).count == 0)
}
}
}

test("negative case: convert a non-delta path falsely claimed as parquet") {
Seq("orc", "json", "csv").foreach { format =>
withTempDir { dir =>
@@ -20,6 +20,7 @@ import org.apache.spark.sql.delta.test.DeltaHiveTest

import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils

@@ -29,12 +30,13 @@ abstract class HiveConvertToDeltaSuiteBase

override protected def convertToDelta(
identifier: String,
partitionSchema: Option[String] = None): Unit = {
partitionSchema: Option[String] = None, collectStats: Boolean = true): Unit = {
if (partitionSchema.isEmpty) {
sql(s"convert to delta $identifier")
sql(s"convert to delta $identifier ${collectStatisticsStringOption(collectStats)} ")
} else {
val stringSchema = partitionSchema.get
sql(s"convert to delta $identifier partitioned by ($stringSchema) ")
sql(s"convert to delta $identifier ${collectStatisticsStringOption(collectStats)}" +
s" partitioned by ($stringSchema) ")
}
}

@@ -47,6 +49,56 @@ abstract class HiveConvertToDeltaSuiteBase
s"Table properties weren't empty for table $tableName: $cleanProps")
}

test("convert with statistics") {
val tbl = "hive_parquet"
withTable(tbl) {
sql(
s"""
|CREATE TABLE $tbl (id int, str string)
|PARTITIONED BY (part string)
|STORED AS PARQUET
""".stripMargin)

sql(s"insert into $tbl VALUES (1, 'a', 1)")

val catalogTable = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tbl))
convertToDelta(tbl, Some("part string"), collectStats = true)
val deltaLog = DeltaLog.forTable(spark, catalogTable)
val statsDf = deltaLog.unsafeVolatileSnapshot.allFiles
.select(from_json(col("stats"), deltaLog.unsafeVolatileSnapshot.statsSchema).as("stats"))
.select("stats.*")
assert(statsDf.filter(col("numRecords").isNull).count == 0)
val history = io.delta.tables.DeltaTable.forPath(catalogTable.location.getPath).history()
assert(history.count == 1)

}
}

test("convert without statistics") {
val tbl = "hive_parquet"
withTable(tbl) {
sql(
s"""
|CREATE TABLE $tbl (id int, str string)
|PARTITIONED BY (part string)
|STORED AS PARQUET
""".stripMargin)

sql(s"insert into $tbl VALUES (1, 'a', 1)")

val catalogTable = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tbl))
convertToDelta(tbl, Some("part string"), collectStats = false)
val deltaLog = DeltaLog.forTable(spark, catalogTable)
val statsDf = deltaLog.unsafeVolatileSnapshot.allFiles
.select(from_json(col("stats"), deltaLog.unsafeVolatileSnapshot.statsSchema).as("stats"))
.select("stats.*")
assert(statsDf.filter(col("numRecords").isNotNull).count == 0)
val history = io.delta.tables.DeltaTable.forPath(catalogTable.location.getPath).history()
assert(history.count == 1)

}
}

test("convert a Hive based parquet table") {
val tbl = "hive_parquet"
withTable(tbl) {
