[SPARK-26188][SQL] FileIndex: don't infer data types of partition columns if user specifies schema

## What changes were proposed in this pull request?

This PR fixes a regression introduced in https://github.com/apache/spark/pull/21004/files#r236998030

If the user specifies a schema, Spark doesn't need to infer the data types of the partition columns; otherwise the inferred types might not match the ones the user provided.
E.g., for the partition directory `p=4d`, the column value becomes `4.0` after data type inference.
See https://issues.apache.org/jira/browse/SPARK-26188 for more details.
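
As a minimal sketch of the intended behavior (illustrative only; it assumes a local `spark` session and parquet data written under a `p=4d` partition directory at `src`):
```
import org.apache.spark.sql.types._

// The user declares the partition column `p` as a string.
val schema = new StructType()
  .add("id", StringType)
  .add("p", StringType)

val df = spark.read
  .schema(schema)
  .format("parquet")
  .load(src.toString)

// With the user-specified schema, the partition value stays the string "4d";
// without this fix, inference could first turn it into the double 4.0.
assert(df.select("p").head().getString(0) == "4d")
```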

Note that the user-specified schema **might not cover all the data columns**:
```
val schema = new StructType()
  .add("id", StringType)
  .add("ex", ArrayType(StringType))
val df = spark.read
  .schema(schema)
  .format("parquet")
  .load(src.toString)

assert(df.schema.toList === List(
  StructField("ex", ArrayType(StringType)),
  StructField("part", IntegerType), // inferred partitionColumn dataType
  StructField("id", StringType))) // used user provided partitionColumn dataType
```
For the columns missing from the user-specified schema, Spark still needs to infer their data types if `partitionColumnTypeInferenceEnabled` is enabled.

To implement this partial inference, refactor `PartitioningUtils.parsePartitions` and pass the user-specified schema as a parameter, so that partition values can be cast to the user-provided data types.
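
Roughly, the per-column handling becomes the following (a condensed sketch, not the exact patch code; `userSpecifiedDataTypes` is the name-to-type map built from the user-specified schema, and the full change is in `PartitioningUtils.parsePartitionColumn` below):
```
val literal = userSpecifiedDataTypes.get(columnName) match {
  case Some(userType) =>
    // Parse the raw directory value without inference, then cast it to the
    // data type the user specified for this partition column.
    val raw = inferPartitionColumnValue(rawColumnValue, typeInference = false, timeZone)
    Literal.create(Cast(raw, userType, Option(timeZone.getID)).eval(), userType)
  case None =>
    // No user-provided type for this column: keep the existing inference path.
    inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
}
```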

## How was this patch tested?

Add unit test.

Closes apache#23165 from gengliangwang/fixFileIndex.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 9cfc3ee)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
gengliangwang authored and kai-chi committed Jul 23, 2019
1 parent 7319292 commit 1016375
Showing 4 changed files with 72 additions and 52 deletions.
PartitioningAwareFileIndex.scala
@@ -126,33 +126,15 @@ abstract class PartitioningAwareFileIndex(
val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
val inferredPartitionSpec = PartitioningUtils.parsePartitions(

val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis
PartitioningUtils.parsePartitions(
leafDirs,
typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
basePaths = basePaths,
userSpecifiedSchema = userSpecifiedSchema,
caseSensitive = caseSensitive,
timeZoneId = timeZoneId)
userSpecifiedSchema match {
case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
val userPartitionSchema =
combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec)

// we need to cast into the data type that user specified.
def castPartitionValuesToUserSchema(row: InternalRow) = {
InternalRow((0 until row.numFields).map { i =>
val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType
Cast(
Literal.create(row.get(i, dt), dt),
userPartitionSchema.fields(i).dataType,
Option(timeZoneId)).eval()
}: _*)
}

PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part =>
part.copy(values = castPartitionValuesToUserSchema(part.values))
})
case _ =>
inferredPartitionSpec
}
}

private def prunePartitions(
@@ -233,25 +215,6 @@ abstract class PartitioningAwareFileIndex(
val name = path.getName
!((name.startsWith("_") && !name.contains("=")) || name.startsWith("."))
}

/**
* In the read path, only managed tables by Hive provide the partition columns properly when
* initializing this class. All other file based data sources will try to infer the partitioning,
* and then cast the inferred types to user specified dataTypes if the partition columns exist
* inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
* inconsistent data types as reported in SPARK-21463.
* @param spec A partition inference result
* @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
*/
private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = {
val equality = sparkSession.sessionState.conf.resolver
val resolved = spec.partitionColumns.map { partitionField =>
// SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
partitionField)
}
StructType(resolved)
}
}

object PartitioningAwareFileIndex {
PartitioningUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

@@ -94,18 +94,34 @@ object PartitioningUtils {
paths: Seq[Path],
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
timeZoneId: String): PartitionSpec = {
parsePartitions(paths, typeInference, basePaths, DateTimeUtils.getTimeZone(timeZoneId))
parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema,
caseSensitive, DateTimeUtils.getTimeZone(timeZoneId))
}

private[datasources] def parsePartitions(
paths: Seq[Path],
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
timeZone: TimeZone): PartitionSpec = {
val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) {
val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap
if (!caseSensitive) {
CaseInsensitiveMap(nameToDataType)
} else {
nameToDataType
}
} else {
Map.empty[String, DataType]
}

// First, we need to parse every partition's path and see if we can find partition values.
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
parsePartition(path, typeInference, basePaths, timeZone)
parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone)
}.unzip

// We create pairs of (path -> path's partition value) here
@@ -147,7 +163,7 @@ object PartitioningUtils {
columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
// We always assume partition columns are nullable since we've no idea whether null values
// will be appended in the future.
StructField(name, dataType, nullable = true)
StructField(name, userSpecifiedDataTypes.getOrElse(name, dataType), nullable = true)
}
}

@@ -185,6 +201,7 @@
path: Path,
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedDataTypes: Map[String, DataType],
timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
val columns = ArrayBuffer.empty[(String, Literal)]
// Old Hadoop versions don't have `Path.isRoot`
@@ -206,7 +223,7 @@
// Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
// Once we get the string, we try to parse it and find the partition column and value.
val maybeColumn =
parsePartitionColumn(currentPath.getName, typeInference, timeZone)
parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone)
maybeColumn.foreach(columns += _)

// Now, we determine if we should stop.
@@ -239,6 +256,7 @@
private def parsePartitionColumn(
columnSpec: String,
typeInference: Boolean,
userSpecifiedDataTypes: Map[String, DataType],
timeZone: TimeZone): Option[(String, Literal)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
@@ -250,7 +268,16 @@
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")

val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
val literal = if (userSpecifiedDataTypes.contains(columnName)) {
// SPARK-26188: if user provides corresponding column schema, get the column value without
// inference, and then cast it as user specified data type.
val columnValue = inferPartitionColumnValue(rawColumnValue, false, timeZone)
val castedValue =
Cast(columnValue, userSpecifiedDataTypes(columnName), Option(timeZone.getID)).eval()
Literal.create(castedValue, userSpecifiedDataTypes(columnName))
} else {
inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
}
Some(columnName -> literal)
}
}
FileIndexSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}

class FileIndexSuite extends SharedSQLContext {
@@ -49,6 +50,21 @@
}
}

test("SPARK-26188: don't infer data types of partition columns if user specifies schema") {
withTempDir { dir =>
val partitionDirectory = new File(dir, s"a=4d")
partitionDirectory.mkdir()
val file = new File(partitionDirectory, "text.txt")
stringToFile(file, "text")
val path = new Path(dir.getCanonicalPath)
val schema = StructType(Seq(StructField("a", StringType, false)))
val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
val partitionValues = fileIndex.partitionSpec().partitions.map(_.values)
assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 &&
partitionValues(0).getString(0) == "4d")
}
}

test("InMemoryFileIndex: input paths are converted to qualified paths") {
withTempDir { dir =>
val file = new File(dir, "text.txt")
ParquetPartitionDiscoverySuite.scala
@@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
"hdfs://host:9000/path/a=10.5/b=hello")

var exception = intercept[AssertionError] {
parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId)
parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))

@@ -115,6 +115,8 @@
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/")),
None,
true,
timeZoneId)

// Valid
@@ -128,6 +130,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/something=true/table")),
None,
true,
timeZoneId)

// Valid
@@ -141,6 +145,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/table=true")),
None,
true,
timeZoneId)

// Invalid
@@ -154,6 +160,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/")),
None,
true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -174,20 +182,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/tmp/tables/")),
None,
true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
}

test("parse partition") {
def check(path: String, expected: Option[PartitionValues]): Unit = {
val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1
val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1
assert(expected === actual)
}

def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
val message = intercept[T] {
parsePartition(new Path(path), true, Set.empty[Path], timeZone)
parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)
}.getMessage

assert(message.contains(expected))
@@ -231,6 +241,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
path = new Path("file://path/a=10"),
typeInference = true,
basePaths = Set(new Path("file://path/a=10")),
Map.empty,
timeZone = timeZone)._1

assert(partitionSpec1.isEmpty)
@@ -240,6 +251,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
path = new Path("file://path/a=10"),
typeInference = true,
basePaths = Set(new Path("file://path")),
Map.empty,
timeZone = timeZone)._1

assert(partitionSpec2 ==
@@ -258,6 +270,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
paths.map(new Path(_)),
true,
rootPaths,
None,
true,
timeZoneId)
assert(actualSpec.partitionColumns === spec.partitionColumns)
assert(actualSpec.partitions.length === spec.partitions.length)
@@ -370,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
test("parse partitions with type inference disabled") {
def check(paths: Seq[String], spec: PartitionSpec): Unit = {
val actualSpec =
parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId)
parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId)
assert(actualSpec === spec)
}

Expand Down
