diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f5f6adc287c48..4bec87e552efb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -76,25 +76,14 @@ object DateTimeUtils {
     }
   }
 
-  private val threadLocalTimestampFormatter = new ThreadLocal[TimestampFormatter] {
+  private val threadLocalTimestampFormat = new ThreadLocal[TimestampFormatter] {
     override def initialValue(): TimestampFormatter = {
       TimestampFormatter("yyyy-MM-dd HH:mm:ss", TimeZoneUTC, Locale.US)
     }
   }
 
-  def getThreadLocalTimestampParser(timeZone: TimeZone): TimestampFormatter = {
-    val timestampFormatter = threadLocalTimestampParser.get()
-    timestampFormatter.withTimeZone(timeZone)
-  }
-
-  private val threadLocalTimestampParser = new ThreadLocal[TimestampFormatter] {
-    override def initialValue(): TimestampFormatter = {
-      TimestampFormatter("yyyy-MM-dd HH:mm:ss[.S]", TimeZoneUTC, Locale.US)
-    }
-  }
-
-  def getThreadLocalTimestampFormatter(timeZone: TimeZone): TimestampFormatter = {
-    val timestampFormatter = threadLocalTimestampFormatter.get()
+  def getThreadLocalTimestampFormat(timeZone: TimeZone): TimestampFormatter = {
+    val timestampFormatter = threadLocalTimestampFormat.get()
     timestampFormatter.withTimeZone(timeZone)
   }
 
@@ -149,7 +138,7 @@ object DateTimeUtils {
   def timestampToString(us: SQLTimestamp, timeZone: TimeZone): String = {
     val ts = toJavaTimestamp(us)
     val timestampString = ts.toString
-    val timestampFormat = getThreadLocalTimestampFormatter(timeZone)
+    val timestampFormat = getThreadLocalTimestampFormat(timeZone)
     val formatted = timestampFormat.format(us)
 
     if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
@@ -1134,7 +1123,7 @@ object DateTimeUtils {
    */
   private[util] def resetThreadLocals(): Unit = {
     threadLocalGmtCalendar.remove()
-    threadLocalTimestampFormatter.remove()
+    threadLocalTimestampFormat.remove()
     threadLocalDateFormat.remove()
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 320583876690e..47f63e53725e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 
@@ -58,6 +58,9 @@ object PartitionSpec {
 
 object PartitioningUtils {
 
+  val datePartitionPattern = "yyyy-MM-dd"
+  val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"
+
   private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal])
   {
     require(columnNames.size == literals.size)
@@ -121,10 +124,12 @@ object PartitioningUtils {
       Map.empty[String, DataType]
     }
 
+    val dateFormatter = DateFormatter(datePartitionPattern, Locale.US)
+    val timestampFormatter = TimestampFormatter(timestampPartitionPattern, timeZone, Locale.US)
     // First, we need to parse every partition's path and see if we can find partition values.
     val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
       parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
-        validatePartitionColumns, timeZone)
+        validatePartitionColumns, timeZone, dateFormatter, timestampFormatter)
     }.unzip
 
     // We create pairs of (path -> path's partition value) here
@@ -207,7 +212,9 @@ object PartitioningUtils {
       basePaths: Set[Path],
       userSpecifiedDataTypes: Map[String, DataType],
       validatePartitionColumns: Boolean,
-      timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
     val columns = ArrayBuffer.empty[(String, Literal)]
     // Old Hadoop versions don't have `Path.isRoot`
     var finished = path.getParent == null
@@ -229,7 +236,7 @@ object PartitioningUtils {
         // Once we get the string, we try to parse it and find the partition column and value.
         val maybeColumn =
           parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
-            validatePartitionColumns, timeZone)
+            validatePartitionColumns, timeZone, dateFormatter, timestampFormatter)
         maybeColumn.foreach(columns += _)
 
         // Now, we determine if we should stop.
@@ -264,7 +271,9 @@ object PartitioningUtils {
       typeInference: Boolean,
       userSpecifiedDataTypes: Map[String, DataType],
       validatePartitionColumns: Boolean,
-      timeZone: TimeZone): Option[(String, Literal)] = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): Option[(String, Literal)] = {
     val equalSignIndex = columnSpec.indexOf('=')
     if (equalSignIndex == -1) {
       None
@@ -279,7 +288,12 @@ object PartitioningUtils {
         // SPARK-26188: if user provides corresponding column schema, get the column value without
         // inference, and then cast it as user specified data type.
         val dataType = userSpecifiedDataTypes(columnName)
-        val columnValueLiteral = inferPartitionColumnValue(rawColumnValue, false, timeZone)
+        val columnValueLiteral = inferPartitionColumnValue(
+          rawColumnValue,
+          false,
+          timeZone,
+          dateFormatter,
+          timestampFormatter)
         val columnValue = columnValueLiteral.eval()
         val castedValue = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval()
         if (validatePartitionColumns && columnValue != null && castedValue == null) {
@@ -288,7 +302,12 @@ object PartitioningUtils {
         }
         Literal.create(castedValue, dataType)
       } else {
-        inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
+        inferPartitionColumnValue(
+          rawColumnValue,
+          typeInference,
+          timeZone,
+          dateFormatter,
+          timestampFormatter)
       }
       Some(columnName -> literal)
     }
@@ -441,7 +460,9 @@ object PartitioningUtils {
   private[datasources] def inferPartitionColumnValue(
       raw: String,
       typeInference: Boolean,
-      timeZone: TimeZone): Literal = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): Literal = {
     val decimalTry = Try {
       // `BigDecimal` conversion can fail when the `field` is not a form of number.
       val bigDecimal = new JBigDecimal(raw)
@@ -456,7 +477,7 @@ object PartitioningUtils {
     val dateTry = Try {
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // DateType
-      DateTimeUtils.getThreadLocalDateFormat().parse(raw)
+      dateFormatter.parse(raw)
       // SPARK-23436: Casting the string to date may still return null if a bad Date is provided.
       // This can happen since DateFormat.parse may not use the entire text of the given string:
       // so if there are extra-characters after the date, it returns correctly.
@@ -473,7 +494,7 @@ object PartitioningUtils {
       val unescapedRaw = unescapePathName(raw)
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // TimestampType
-      DateTimeUtils.getThreadLocalTimestampParser(timeZone).parse(unescapedRaw)
+      timestampFormatter.parse(unescapedRaw)
       // SPARK-23436: see comment for date
       val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval()
       // Disallow TimestampType if the cast returned null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 88067358667c6..ad2325b01f524 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -56,6 +56,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
   val timeZone = TimeZone.getDefault()
   val timeZoneId = timeZone.getID
+  val df = DateFormatter(datePartitionPattern, Locale.US)
+  val tf = TimestampFormatter(timestampPartitionPattern, timeZone, Locale.US)
 
   protected override def beforeAll(): Unit = {
     super.beforeAll()
@@ -69,7 +71,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
   test("column type inference") {
     def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = {
-      assert(inferPartitionColumnValue(raw, true, timeZone) === literal)
+      assert(inferPartitionColumnValue(raw, true, timeZone, df, tf) === literal)
     }
 
     check("10", Literal.create(10, IntegerType))
@@ -197,13 +199,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
   test("parse partition") {
     def check(path: String, expected: Option[PartitionValues]): Unit = {
       val actual = parsePartition(new Path(path), true, Set.empty[Path],
-        Map.empty, true, timeZone)._1
+        Map.empty, true, timeZone, df, tf)._1
       assert(expected === actual)
     }
 
     def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
       val message = intercept[T] {
-        parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone)
+        parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone, df, tf)
       }.getMessage
 
       assert(message.contains(expected))
@@ -249,7 +251,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       basePaths = Set(new Path("file://path/a=10")),
       Map.empty,
       true,
-      timeZone = timeZone)._1
+      timeZone = timeZone,
+      df,
+      tf)._1
 
     assert(partitionSpec1.isEmpty)
 
@@ -260,7 +264,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       basePaths = Set(new Path("file://path")),
       Map.empty,
       true,
-      timeZone = timeZone)._1
+      timeZone = timeZone,
+      df,
+      tf)._1
 
     assert(partitionSpec2 == Option(PartitionValues(