Commit 522f7de
update code
sadikovi committed Nov 19, 2021
1 parent f9d097c commit 522f7de
Showing 5 changed files with 50 additions and 45 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -106,19 +106,9 @@ private[sql] class JSONOptions(
s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]"
})

val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat").orElse {
if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) {
Some(s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSS")
} else {
None
}
}
val timestampNTZFormatInWrite: String = parameters.getOrElse("timestampNTZFormat",
if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) {
s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSS"
} else {
s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]"
})
val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat")
val timestampNTZFormatInWrite: String =
parameters.getOrElse("timestampNTZFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]")

val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)

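Aside (illustrative, not part of the diff): with the LEGACY-policy fallback removed, an absent
"timestampNTZFormat" leaves `timestampNTZFormatInRead` as `None` and writes default to the ISO-style
pattern. A minimal usage sketch — the column name and path are hypothetical:

    // Hedged sketch: supplying the option explicitly when reading TIMESTAMP_NTZ columns.
    val df = spark.read
      .schema("col0 TIMESTAMP_NTZ")
      .option("timestampNTZFormat", "yyyy-MM-dd'T'HH:mm:ss[.SSS]")
      .json("/tmp/ntz-json")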
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -151,7 +151,6 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
         if (options.prefersDecimal && decimalTry.isDefined) {
           decimalTry.get
         } else if (options.inferTimestamp &&
-            (allCatch opt !timestampNTZFormatter.isTimeZoneSet(field)).getOrElse(false) &&
             (allCatch opt timestampNTZFormatter.parseWithoutTimeZone(field)).isDefined) {
           SQLConf.get.timestampType
         } else if (options.inferTimestamp &&
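Aside (illustrative, not part of the diff): the separate `isTimeZoneSet` probe becomes redundant once
`parseWithoutTimeZone` itself throws on zoned input, because the surrounding `allCatch opt` already
turns any parse failure into `None`. A self-contained sketch of that idiom — the `parse` helper here
is a hypothetical stand-in, not Spark's formatter:

    import scala.util.control.Exception.allCatch

    // Stand-in for parseWithoutTimeZone: rejects strings carrying a zone suffix.
    def parse(s: String): Long =
      if (s.endsWith("Z")) throw new RuntimeException("zoned input") else 42L

    assert((allCatch opt parse("2020-12-12T12:12:12")).isDefined)  // Some(42)
    assert((allCatch opt parse("2020-12-12T12:12:12Z")).isEmpty)   // exception => None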
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
@@ -31,9 +31,10 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.{LegacyDateFormat, LENIENT_SIMPLE_DATE_FORMAT}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime._
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._
-import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.types.{Decimal, TimestampNTZType}
 import org.apache.spark.unsafe.types.UTF8String
 
 sealed trait TimestampFormatter extends Serializable {
@@ -71,19 +72,6 @@ sealed trait TimestampFormatter extends Serializable {
s"The method `parseWithoutTimeZone(s: String)` should be implemented in the formatter " +
"of timestamp without time zone")

/**
* Returns true if the parsed timestamp contains the time zone component, false otherwise.
* Used to determine if the timestamp can be inferred as timestamp without time zone.
*
* @param s - string with timestamp to inspect
* @return whether the timestamp string has the time zone component defined.
*/
@throws(classOf[IllegalStateException])
def isTimeZoneSet(s: String): Boolean =
throw new IllegalStateException(
s"The method `isTimeZoneSet(s: String)` should be implemented in the formatter " +
"of timestamp without time zone")

def format(us: Long): String
def format(ts: Timestamp): String
def format(instant: Instant): String
@@ -134,20 +122,16 @@ class Iso8601TimestampFormatter(
   override def parseWithoutTimeZone(s: String): Long = {
     try {
       val parsed = formatter.parse(s)
+      val parsedZoneId = parsed.query(TemporalQueries.zone())
+      if (parsedZoneId != null) {
+        throw QueryExecutionErrors.cannotParseStringAsDataTypeError(pattern, s, TimestampNTZType)
+      }
       val localDate = toLocalDate(parsed)
       val localTime = toLocalTime(parsed)
       DateTimeUtils.localDateTimeToMicros(LocalDateTime.of(localDate, localTime))
     } catch checkParsedDiff(s, legacyFormatter.parse)
   }
 
-  override def isTimeZoneSet(s: String): Boolean = {
-    try {
-      val parsed = formatter.parse(s)
-      val parsedZoneId = parsed.query(TemporalQueries.zone())
-      parsedZoneId != null
-    } catch checkParsedDiff(s, legacyFormatter.isTimeZoneSet)
-  }
-
   override def format(instant: Instant): String = {
     try {
       formatter.withZone(zoneId).format(instant)
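Aside (illustrative, not part of the diff): `TemporalQueries.zone()` resolves to a `ZoneId` when the
parsed text carries a zone or offset and to `null` otherwise, which is exactly what the added guard
keys on. A standalone sketch using only java.time — the pattern below is hypothetical:

    import java.time.format.DateTimeFormatter
    import java.time.temporal.TemporalQueries

    val fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss[XXX]")

    // No zone in the input: the query yields null, so NTZ parsing proceeds.
    assert(fmt.parse("2020-12-12T12:12:12").query(TemporalQueries.zone()) == null)

    // Offset present: the query yields a ZoneId, so the formatter now rejects the value.
    assert(fmt.parse("2020-12-12T12:12:12+05:00").query(TemporalQueries.zone()) != null)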
@@ -209,16 +193,15 @@ class DefaultTimestampFormatter(

   override def parseWithoutTimeZone(s: String): Long = {
     try {
-      DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(UTF8String.fromString(s))
+      val utf8Value = UTF8String.fromString(s)
+      val (_, zoneIdOpt, _) = DateTimeUtils.parseTimestampString(utf8Value)
+      if (zoneIdOpt.isDefined) {
+        throw QueryExecutionErrors.cannotParseStringAsDataTypeError(
+          TimestampFormatter.defaultPattern(), s, TimestampNTZType)
+      }
+      DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(utf8Value)
     } catch checkParsedDiff(s, legacyFormatter.parse)
   }
-
-  override def isTimeZoneSet(s: String): Boolean = {
-    try {
-      val (_, zoneIdOpt, _) = parseTimestampString(UTF8String.fromString(s))
-      zoneIdOpt.isDefined
-    } catch checkParsedDiff(s, legacyFormatter.isTimeZoneSet)
-  }
 }
 
 /**
sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1034,6 +1034,13 @@ object QueryExecutionErrors {
s"[$token] as target spark data type [$dataType].")
}

def cannotParseStringAsDataTypeError(pattern: String, value: String, dataType: DataType)
: Throwable = {
new RuntimeException(
s"Cannot parse field value ${value} for pattern ${pattern} " +
s"as target spark data type [$dataType].")
}

def failToParseEmptyStringForDataTypeError(dataType: DataType): Throwable = {
new RuntimeException(
s"Failed to parse an empty string for data type ${dataType.catalogString}")
@@ -1890,4 +1897,3 @@
     new UnsupportedOperationException(s"Hive table $tableName with ANSI intervals is not supported")
   }
 }
-
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2861,6 +2861,33 @@ abstract class JsonSuite
     }
   }
 
+  test("SPARK-37326: Malformed records when reading TIMESTAMP_LTZ as TIMESTAMP_NTZ") {
+    withTempDir { dir =>
+      val path = s"${dir.getCanonicalPath}/json"
+
+      Seq(
+        """{"col0": "2020-12-12T12:12:12.000"}""",
+        """{"col0": "2020-12-12T12:12:12.000Z"}""",
+        """{"col0": "2020-12-12T12:12:12.000+05:00"}""",
+        """{"col0": "2020-12-12T12:12:12.000"}"""
+      ).toDF("data")
+        .coalesce(1)
+        .write.text(path)
+
+      val res = spark.read.schema("col0 TIMESTAMP_NTZ").json(path)
+
+      checkAnswer(
+        res,
+        Seq(
+          Row(LocalDateTime.of(2020, 12, 12, 12, 12, 12)),
+          Row(null),
+          Row(null),
+          Row(LocalDateTime.of(2020, 12, 12, 12, 12, 12))
+        )
+      )
+    }
+  }
+
   test("SPARK-37326: Fail to write TIMESTAMP_NTZ if timestampNTZFormat contains zone offset") {
     val patterns = Seq(
       "yyyy-MM-dd HH:mm:ss XXX",
