Skip to content

Commit

Permalink
Do not quote non-numeric numbers (NaN, Infinity, -Infinity) when allowNonNumericNumbers is enabled.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 17, 2015
1 parent 2777677 commit b2a835d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ case class JSONOptions(
allowUnquotedFieldNames: Boolean = false,
allowSingleQuotes: Boolean = true,
allowNumericLeadingZeros: Boolean = false,
allowNonNumericNumbers: Boolean = false) {
allowNonNumericNumbers: Boolean = true) {

/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
Expand Down Expand Up @@ -59,6 +59,6 @@ object JSONOptions {
allowNumericLeadingZeros =
parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false),
allowNonNumericNumbers =
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(false)
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json

import java.io.CharArrayWriter

import com.fasterxml.jackson.core.JsonFactory
import com.fasterxml.jackson.core.{JsonGenerator, JsonFactory}
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
Expand Down Expand Up @@ -161,12 +161,13 @@ private[sql] class JSONRelation(
}

override def prepareJobForWrite(job: Job): OutputWriterFactory = {
val quoteNonNumeric = !options.allowNonNumericNumbers
new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new JsonOutputWriter(path, dataSchema, context)
new JsonOutputWriter(path, dataSchema, context, quoteNonNumeric)
}
}
}
Expand All @@ -175,12 +176,15 @@ private[sql] class JSONRelation(
private[json] class JsonOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext)
context: TaskAttemptContext,
quoteNonNumeric: Boolean)
extends OutputWriter with SparkHadoopMapRedUtil with Logging {

private[this] val writer = new CharArrayWriter()
// create the Generator without separator inserted between 2 records
private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
private[this] val factory = new JsonFactory()
factory.configure(JsonGenerator.Feature.QUOTE_NON_NUMERIC_NUMBERS, quoteNonNumeric)
private[this] val gen = factory.createGenerator(writer).setRootValueSeparator(null)
private[this] val result = new Text()

private val recordWriter: RecordWriter[NullWritable, Text] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,7 @@ object JacksonParser {
parser.getDoubleValue

case (VALUE_STRING, DoubleType) =>
// Special case handling for NaN and Infinity.
val value = parser.getText
val lowerCaseValue = value.toLowerCase()
if (lowerCaseValue.equals("nan") ||
lowerCaseValue.equals("infinity") ||
lowerCaseValue.equals("-infinity") ||
lowerCaseValue.equals("inf") ||
lowerCaseValue.equals("-inf")) {
value.toDouble
} else {
sys.error(s"Cannot parse $value as DoubleType.")
}
parser.getDoubleValue

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
Decimal(parser.getDecimalValue, dt.precision, dt.scale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,28 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
}

test("allowNonNumericNumbers off") {
val str = """{"age": NaN}"""
val rdd = sqlContext.sparkContext.parallelize(Seq(str))
val df = sqlContext.read.json(rdd)
val testCases: Seq[String] = Seq("""{"age": NaN}""", """{"age": Infinity}""",
"""{"age": -Infinity}""")

assert(df.schema.head.name == "_corrupt_record")
testCases.foreach { str =>
val rdd = sqlContext.sparkContext.parallelize(Seq(str))
val df = sqlContext.read.option("allowNonNumericNumbers", "false").json(rdd)

assert(df.schema.head.name == "_corrupt_record")
}
}

test("allowNonNumericNumbers on") {
val str = """{"age": NaN}"""
val rdd = sqlContext.sparkContext.parallelize(Seq(str))
val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd)
val testCases: Seq[String] = Seq("""{"age": NaN}""", """{"age": Infinity}""",
"""{"age": -Infinity}""")
val tests: Seq[Double => Boolean] = Seq(_.isNaN, _.isPosInfinity, _.isNegInfinity)

assert(df.schema.head.name == "age")
assert(df.first().getDouble(0).isNaN)
testCases.zipWithIndex.foreach { case (str, idx) =>
val rdd = sqlContext.sparkContext.parallelize(Seq(str))
val df = sqlContext.read.json(rdd)

assert(df.schema.head.name == "age")
assert(tests(idx)(df.first().getDouble(0)))
}
}
}

0 comments on commit b2a835d

Please sign in to comment.