diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
index 8c88957ca8e1c..0332b8243cf02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
@@ -31,7 +31,7 @@ case class JSONOptions(
     allowUnquotedFieldNames: Boolean = false,
     allowSingleQuotes: Boolean = true,
     allowNumericLeadingZeros: Boolean = false,
-    allowNonNumericNumbers: Boolean = false) {
+    allowNonNumericNumbers: Boolean = true) {
 
   /** Sets config options on a Jackson [[JsonFactory]]. */
   def setJacksonOptions(factory: JsonFactory): Unit = {
@@ -59,6 +59,6 @@
     allowNumericLeadingZeros =
       parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false),
     allowNonNumericNumbers =
-      parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(false)
+      parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
   )
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 3e61ba35bea8e..b7f9592f7cd50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json
 
 import java.io.CharArrayWriter
 
-import com.fasterxml.jackson.core.JsonFactory
+import com.fasterxml.jackson.core.{JsonGenerator, JsonFactory}
 import com.google.common.base.Objects
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
@@ -161,12 +161,13 @@ private[sql] class JSONRelation(
   }
 
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+    val quoteNonNumeric = !options.allowNonNumericNumbers
     new OutputWriterFactory {
       override def newInstance(
           path: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new JsonOutputWriter(path, dataSchema, context)
+        new JsonOutputWriter(path, dataSchema, context, quoteNonNumeric)
       }
     }
   }
@@ -175,12 +176,15 @@ private[sql] class JSONRelation(
 private[json] class JsonOutputWriter(
     path: String,
     dataSchema: StructType,
-    context: TaskAttemptContext)
+    context: TaskAttemptContext,
+    quoteNonNumeric: Boolean)
   extends OutputWriter with SparkHadoopMapRedUtil with Logging {
 
   private[this] val writer = new CharArrayWriter()
   // create the Generator without separator inserted between 2 records
-  private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+  private[this] val factory = new JsonFactory()
+  factory.configure(JsonGenerator.Feature.QUOTE_NON_NUMERIC_NUMBERS, quoteNonNumeric)
+  private[this] val gen = factory.createGenerator(writer).setRootValueSeparator(null)
   private[this] val result = new Text()
 
   private val recordWriter: RecordWriter[NullWritable, Text] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
index 0bfc7635ed79b..eea80a331261a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
@@ -106,18 +106,7 @@
         parser.getDoubleValue
 
       case (VALUE_STRING, DoubleType) =>
-        // Special case handling for NaN and Infinity.
-        val value = parser.getText
-        val lowerCaseValue = value.toLowerCase()
-        if (lowerCaseValue.equals("nan") ||
-          lowerCaseValue.equals("infinity") ||
-          lowerCaseValue.equals("-infinity") ||
-          lowerCaseValue.equals("inf") ||
-          lowerCaseValue.equals("-inf")) {
-          value.toDouble
-        } else {
-          sys.error(s"Cannot parse $value as DoubleType.")
-        }
+        parser.getDoubleValue
 
       case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
         Decimal(parser.getDecimalValue, dt.precision, dt.scale)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
index 522670395152f..094a77d631f8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
@@ -94,19 +94,28 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
   }
 
   test("allowNonNumericNumbers off") {
-    val str = """{"age": NaN}"""
-    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
-    val df = sqlContext.read.json(rdd)
+    val testCases: Seq[String] = Seq("""{"age": NaN}""", """{"age": Infinity}""",
+      """{"age": -Infinity}""")
 
-    assert(df.schema.head.name == "_corrupt_record")
+    testCases.foreach { str =>
+      val rdd = sqlContext.sparkContext.parallelize(Seq(str))
+      val df = sqlContext.read.option("allowNonNumericNumbers", "false").json(rdd)
+
+      assert(df.schema.head.name == "_corrupt_record")
+    }
   }
 
   test("allowNonNumericNumbers on") {
-    val str = """{"age": NaN}"""
-    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
-    val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd)
+    val testCases: Seq[String] = Seq("""{"age": NaN}""", """{"age": Infinity}""",
+      """{"age": -Infinity}""")
+    val tests: Seq[Double => Boolean] = Seq(_.isNaN, _.isPosInfinity, _.isNegInfinity)
 
-    assert(df.schema.head.name == "age")
-    assert(df.first().getDouble(0).isNaN)
+    testCases.zipWithIndex.foreach { case (str, idx) =>
+      val rdd = sqlContext.sparkContext.parallelize(Seq(str))
+      val df = sqlContext.read.json(rdd)
+
+      assert(df.schema.head.name == "age")
+      assert(tests(idx)(df.first().getDouble(0)))
+    }
   }
 }
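
A minimal usage sketch of the behavior this diff changes, mirroring the updated tests. It assumes a Spark 1.6-era sqlContext and the RDD[String] JSON reader; the records/permissive/strict names are illustrative, not part of the change.

// A record using one of Jackson's non-numeric number literals (NaN, Infinity, -Infinity).
val records = sqlContext.sparkContext.parallelize(Seq("""{"age": NaN}"""))

// With this patch the default is permissive: the literal parses via parser.getDoubleValue,
// so "age" is inferred as DoubleType instead of the row being flagged as corrupt.
val permissive = sqlContext.read.json(records)
assert(permissive.schema.head.name == "age")
assert(permissive.first().getDouble(0).isNaN)

// Opting out restores strict parsing: the unquoted NaN token is rejected and the
// record surfaces under the _corrupt_record column.
val strict = sqlContext.read.option("allowNonNumericNumbers", "false").json(records)
assert(strict.schema.head.name == "_corrupt_record")

On the write side, quoteNonNumeric = !options.allowNonNumericNumbers means JsonOutputWriter quotes NaN/Infinity values (keeping the output parseable as strict JSON) exactly when the option is off, and emits the bare tokens otherwise.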