Commit 5f0b0a3: Moving partition timestamp and date parsers to PartitioningUtils

MaxGekk committed Dec 28, 2018
1 parent f0a9fe7

Showing 3 changed files with 48 additions and 32 deletions.
DateTimeUtils.scala
@@ -76,25 +76,14 @@ object DateTimeUtils {
     }
   }

-  private val threadLocalTimestampFormatter = new ThreadLocal[TimestampFormatter] {
+  private val threadLocalTimestampFormat = new ThreadLocal[TimestampFormatter] {
     override def initialValue(): TimestampFormatter = {
       TimestampFormatter("yyyy-MM-dd HH:mm:ss", TimeZoneUTC, Locale.US)
     }
   }

-  def getThreadLocalTimestampParser(timeZone: TimeZone): TimestampFormatter = {
-    val timestampFormatter = threadLocalTimestampParser.get()
-    timestampFormatter.withTimeZone(timeZone)
-  }
-
-  private val threadLocalTimestampParser = new ThreadLocal[TimestampFormatter] {
-    override def initialValue(): TimestampFormatter = {
-      TimestampFormatter("yyyy-MM-dd HH:mm:ss[.S]", TimeZoneUTC, Locale.US)
-    }
-  }
-
-  def getThreadLocalTimestampFormatter(timeZone: TimeZone): TimestampFormatter = {
-    val timestampFormatter = threadLocalTimestampFormatter.get()
+  def getThreadLocalTimestampFormat(timeZone: TimeZone): TimestampFormatter = {
+    val timestampFormatter = threadLocalTimestampFormat.get()
     timestampFormatter.withTimeZone(timeZone)
   }
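For context, a minimal sketch of how the surviving thread-local formatter is consumed after this rename; the UTC time zone and the microsecond value are illustrative, not taken from the diff:

    import java.util.TimeZone
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // Fetch the per-thread TimestampFormatter and rebind it to the caller's
    // time zone, as timestampToString does in the next hunk.
    // 1545955200000000L is 2018-12-28 00:00:00 UTC in microseconds since the epoch.
    val utc = TimeZone.getTimeZone("UTC")
    val formatter = DateTimeUtils.getThreadLocalTimestampFormat(utc)
    val rendered = formatter.format(1545955200000000L) // "2018-12-28 00:00:00"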

@@ -149,7 +138,7 @@ object DateTimeUtils {
   def timestampToString(us: SQLTimestamp, timeZone: TimeZone): String = {
     val ts = toJavaTimestamp(us)
     val timestampString = ts.toString
-    val timestampFormat = getThreadLocalTimestampFormatter(timeZone)
+    val timestampFormat = getThreadLocalTimestampFormat(timeZone)
     val formatted = timestampFormat.format(us)

     if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
@@ -1134,7 +1123,7 @@ object DateTimeUtils {
    */
   private[util] def resetThreadLocals(): Unit = {
     threadLocalGmtCalendar.remove()
-    threadLocalTimestampFormatter.remove()
+    threadLocalTimestampFormat.remove()
     threadLocalDateFormat.remove()
   }
 }
PartitioningUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -58,6 +58,9 @@ object PartitionSpec {

 object PartitioningUtils {

+  val datePartitionPattern = "yyyy-MM-dd"
+  val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"
+
   private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal])
   {
     require(columnNames.size == literals.size)
@@ -121,10 +124,12 @@ object PartitioningUtils {
       Map.empty[String, DataType]
     }

+    val dateFormatter = DateFormatter(datePartitionPattern, Locale.US)
+    val timestampFormatter = TimestampFormatter(timestampPartitionPattern, timeZone, Locale.US)
     // First, we need to parse every partition's path and see if we can find partition values.
     val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
       parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
-        validatePartitionColumns, timeZone)
+        validatePartitionColumns, timeZone, dateFormatter, timestampFormatter)
     }.unzip

     // We create pairs of (path -> path's partition value) here
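A minimal sketch of the usage pattern this hunk establishes: the two formatters are built once per discovery pass from the public patterns above and then threaded through every parsing call, instead of being fetched from ThreadLocals per directory. It assumes DateFormatter.parse returns days since the epoch and TimestampFormatter.parse returns microseconds, as the catalyst formatters of this era do; the sample strings are illustrative:

    import java.util.{Locale, TimeZone}
    import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}
    import org.apache.spark.sql.execution.datasources.PartitioningUtils._

    val timeZone = TimeZone.getTimeZone("UTC")
    val dateFormatter = DateFormatter(datePartitionPattern, Locale.US)
    val timestampFormatter = TimestampFormatter(timestampPartitionPattern, timeZone, Locale.US)

    // Both throw on malformed input, which inferPartitionColumnValue wraps in Try.
    val days = dateFormatter.parse("2018-12-28")                   // assumed: days since epoch
    val micros = timestampFormatter.parse("2018-12-28 10:30:00.1") // assumed: microseconds since epoch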
@@ -207,7 +212,9 @@ object PartitioningUtils {
       basePaths: Set[Path],
       userSpecifiedDataTypes: Map[String, DataType],
       validatePartitionColumns: Boolean,
-      timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): (Option[PartitionValues], Option[Path]) = {
     val columns = ArrayBuffer.empty[(String, Literal)]
     // Old Hadoop versions don't have `Path.isRoot`
     var finished = path.getParent == null
@@ -229,7 +236,7 @@ object PartitioningUtils {
       // Once we get the string, we try to parse it and find the partition column and value.
       val maybeColumn =
         parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
-          validatePartitionColumns, timeZone)
+          validatePartitionColumns, timeZone, dateFormatter, timestampFormatter)
       maybeColumn.foreach(columns += _)

       // Now, we determine if we should stop.
@@ -264,7 +271,9 @@ object PartitioningUtils {
       typeInference: Boolean,
       userSpecifiedDataTypes: Map[String, DataType],
       validatePartitionColumns: Boolean,
-      timeZone: TimeZone): Option[(String, Literal)] = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): Option[(String, Literal)] = {
     val equalSignIndex = columnSpec.indexOf('=')
     if (equalSignIndex == -1) {
       None
@@ -279,7 +288,12 @@ object PartitioningUtils {
         // SPARK-26188: if user provides corresponding column schema, get the column value without
         // inference, and then cast it as user specified data type.
         val dataType = userSpecifiedDataTypes(columnName)
-        val columnValueLiteral = inferPartitionColumnValue(rawColumnValue, false, timeZone)
+        val columnValueLiteral = inferPartitionColumnValue(
+          rawColumnValue,
+          false,
+          timeZone,
+          dateFormatter,
+          timestampFormatter)
         val columnValue = columnValueLiteral.eval()
         val castedValue = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval()
         if (validatePartitionColumns && columnValue != null && castedValue == null) {
@@ -288,7 +302,12 @@ object PartitioningUtils {
        }
        Literal.create(castedValue, dataType)
      } else {
-        inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
+        inferPartitionColumnValue(
+          rawColumnValue,
+          typeInference,
+          timeZone,
+          dateFormatter,
+          timestampFormatter)
      }
      Some(columnName -> literal)
    }
@@ -441,7 +460,9 @@ object PartitioningUtils {
   private[datasources] def inferPartitionColumnValue(
       raw: String,
       typeInference: Boolean,
-      timeZone: TimeZone): Literal = {
+      timeZone: TimeZone,
+      dateFormatter: DateFormatter,
+      timestampFormatter: TimestampFormatter): Literal = {
     val decimalTry = Try {
       // `BigDecimal` conversion can fail when the `field` is not a form of number.
       val bigDecimal = new JBigDecimal(raw)
@@ -456,7 +477,7 @@ object PartitioningUtils {
     val dateTry = Try {
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // DateType
-      DateTimeUtils.getThreadLocalDateFormat().parse(raw)
+      dateFormatter.parse(raw)
       // SPARK-23436: Casting the string to date may still return null if a bad Date is provided.
       // This can happen since DateFormat.parse may not use the entire text of the given string:
       // so if there are extra-characters after the date, it returns correctly.
@@ -473,7 +494,7 @@ object PartitioningUtils {
       val unescapedRaw = unescapePathName(raw)
       // try and parse the date, if no exception occurs this is a candidate to be resolved as
       // TimestampType
-      DateTimeUtils.getThreadLocalTimestampParser(timeZone).parse(unescapedRaw)
+      timestampFormatter.parse(unescapedRaw)
       // SPARK-23436: see comment for date
       val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval()
       // Disallow TimestampType if the cast returned null
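A sketch of how the reworked inference behaves with the injected formatters; the first expectation mirrors the test suite below, while the second assumes the usual inference order (numeric types are tried before date, and date before timestamp):

    // Callable only from org.apache.spark.sql.execution.datasources,
    // since inferPartitionColumnValue is private[datasources].
    inferPartitionColumnValue("10", true, timeZone, dateFormatter, timestampFormatter)
    // == Literal.create(10, IntegerType)
    inferPartitionColumnValue("2018-12-28", true, timeZone, dateFormatter, timestampFormatter)
    // succeeds in dateFormatter.parse, then casts the raw string to DateType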
ParquetPartitionDiscoverySuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -56,6 +58,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha

   val timeZone = TimeZone.getDefault()
   val timeZoneId = timeZone.getID
+  val df = DateFormatter(datePartitionPattern, Locale.US)
+  val tf = TimestampFormatter(timestampPartitionPattern, timeZone, Locale.US)

   protected override def beforeAll(): Unit = {
     super.beforeAll()
@@ -69,7 +71,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha

   test("column type inference") {
     def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = {
-      assert(inferPartitionColumnValue(raw, true, timeZone) === literal)
+      assert(inferPartitionColumnValue(raw, true, timeZone, df, tf) === literal)
     }

     check("10", Literal.create(10, IntegerType))
@@ -197,13 +199,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
   test("parse partition") {
     def check(path: String, expected: Option[PartitionValues]): Unit = {
       val actual = parsePartition(new Path(path), true, Set.empty[Path],
-        Map.empty, true, timeZone)._1
+        Map.empty, true, timeZone, df, tf)._1
       assert(expected === actual)
     }

     def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
       val message = intercept[T] {
-        parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone)
+        parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone, df, tf)
       }.getMessage

       assert(message.contains(expected))
@@ -249,7 +251,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       basePaths = Set(new Path("file://path/a=10")),
       Map.empty,
       true,
-      timeZone = timeZone)._1
+      timeZone = timeZone,
+      df,
+      tf)._1

     assert(partitionSpec1.isEmpty)

@@ -260,7 +264,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       basePaths = Set(new Path("file://path")),
       Map.empty,
       true,
-      timeZone = timeZone)._1
+      timeZone = timeZone,
+      df,
+      tf)._1

     assert(partitionSpec2 ==
       Option(PartitionValues(