Skip to content

Commit

Permalink
Deal with quoted non-numeric number.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 20, 2015
1 parent b2a835d commit 186fa5e
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,21 @@ private[json] object InferSchema {
// record fields' types have been combined.
NullType

case VALUE_STRING => StringType
case VALUE_STRING =>
// When JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS is enabled,
// we need to do special handling for quoted non-numeric numbers.
if (configOptions.allowNonNumericNumbers) {
val value = parser.getText
val lowerCaseValue = value.toLowerCase()
if (lowerCaseValue.equals("nan") ||
lowerCaseValue.equals("infinity") ||
lowerCaseValue.equals("-infinity") ||
lowerCaseValue.equals("inf") ||
lowerCaseValue.equals("-inf")) {
return DoubleType
}
}
StringType
case START_OBJECT =>
val builder = Seq.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json

import java.io.CharArrayWriter

import com.fasterxml.jackson.core.{JsonGenerator, JsonFactory}
import com.fasterxml.jackson.core.JsonFactory
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
Expand Down Expand Up @@ -161,13 +161,12 @@ private[sql] class JSONRelation(
}

override def prepareJobForWrite(job: Job): OutputWriterFactory = {
val quoteNonNumeric = !options.allowNonNumericNumbers
new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new JsonOutputWriter(path, dataSchema, context, quoteNonNumeric)
new JsonOutputWriter(path, dataSchema, context)
}
}
}
Expand All @@ -176,15 +175,12 @@ private[sql] class JSONRelation(
private[json] class JsonOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
quoteNonNumeric: Boolean)
context: TaskAttemptContext)
extends OutputWriter with SparkHadoopMapRedUtil with Logging {

private[this] val writer = new CharArrayWriter()
// create the Generator without separator inserted between 2 records
private[this] val factory = new JsonFactory()
factory.configure(JsonGenerator.Feature.QUOTE_NON_NUMERIC_NUMBERS, quoteNonNumeric)
private[this] val gen = factory.createGenerator(writer).setRootValueSeparator(null)
private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
private[this] val result = new Text()

private val recordWriter: RecordWriter[NullWritable, Text] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ object JacksonParser {
def convertField(
factory: JsonFactory,
parser: JsonParser,
schema: DataType): Any = {
schema: DataType,
configOptions: JSONOptions): Any = {
import com.fasterxml.jackson.core.JsonToken._
(parser.getCurrentToken, schema) match {
case (null | VALUE_NULL, _) =>
null

case (FIELD_NAME, _) =>
parser.nextToken()
convertField(factory, parser, schema)
convertField(factory, parser, schema, configOptions)

case (VALUE_STRING, StringType) =>
UTF8String.fromString(parser.getText)
Expand Down Expand Up @@ -106,7 +107,22 @@ object JacksonParser {
parser.getDoubleValue

case (VALUE_STRING, DoubleType) =>
parser.getDoubleValue
// Special case handling for quoted non-numeric numbers.
if (configOptions.allowNonNumericNumbers) {
val value = parser.getText
val lowerCaseValue = value.toLowerCase()
if (lowerCaseValue.equals("nan") ||
lowerCaseValue.equals("infinity") ||
lowerCaseValue.equals("-infinity") ||
lowerCaseValue.equals("inf") ||
lowerCaseValue.equals("-inf")) {
value.toDouble
} else {
sys.error(s"Cannot parse $value as DoubleType.")
}
} else {
parser.getDoubleValue
}

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
Decimal(parser.getDecimalValue, dt.precision, dt.scale)
Expand All @@ -130,26 +146,26 @@ object JacksonParser {
false

case (START_OBJECT, st: StructType) =>
convertObject(factory, parser, st)
convertObject(factory, parser, st, configOptions)

case (START_ARRAY, st: StructType) =>
// SPARK-3308: support reading top level JSON arrays and take every element
// in such an array as a row
convertArray(factory, parser, st)
convertArray(factory, parser, st, configOptions)

case (START_ARRAY, ArrayType(st, _)) =>
convertArray(factory, parser, st)
convertArray(factory, parser, st, configOptions)

case (START_OBJECT, ArrayType(st, _)) =>
// the business end of SPARK-3308:
// when an object is found but an array is requested just wrap it in a list
convertField(factory, parser, st) :: Nil
convertField(factory, parser, st, configOptions) :: Nil

case (START_OBJECT, MapType(StringType, kt, _)) =>
convertMap(factory, parser, kt)
convertMap(factory, parser, kt, configOptions)

case (_, udt: UserDefinedType[_]) =>
convertField(factory, parser, udt.sqlType)
convertField(factory, parser, udt.sqlType, configOptions)

case (token, dataType) =>
sys.error(s"Failed to parse a value for data type $dataType (current token: $token).")
Expand All @@ -164,12 +180,13 @@ object JacksonParser {
private def convertObject(
factory: JsonFactory,
parser: JsonParser,
schema: StructType): InternalRow = {
schema: StructType,
configOptions: JSONOptions): InternalRow = {
val row = new GenericMutableRow(schema.length)
while (nextUntil(parser, JsonToken.END_OBJECT)) {
schema.getFieldIndex(parser.getCurrentName) match {
case Some(index) =>
row.update(index, convertField(factory, parser, schema(index).dataType))
row.update(index, convertField(factory, parser, schema(index).dataType, configOptions))

case None =>
parser.skipChildren()
Expand All @@ -185,23 +202,25 @@ object JacksonParser {
private def convertMap(
factory: JsonFactory,
parser: JsonParser,
valueType: DataType): MapData = {
valueType: DataType,
configOptions: JSONOptions): MapData = {
val keys = ArrayBuffer.empty[UTF8String]
val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
keys += UTF8String.fromString(parser.getCurrentName)
values += convertField(factory, parser, valueType)
values += convertField(factory, parser, valueType, configOptions)
}
ArrayBasedMapData(keys.toArray, values.toArray)
}

private def convertArray(
factory: JsonFactory,
parser: JsonParser,
elementType: DataType): ArrayData = {
elementType: DataType,
configOptions: JSONOptions): ArrayData = {
val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
values += convertField(factory, parser, elementType)
values += convertField(factory, parser, elementType, configOptions)
}

new GenericArrayData(values.toArray)
Expand Down Expand Up @@ -235,7 +254,7 @@ object JacksonParser {
Utils.tryWithResource(factory.createParser(record)) { parser =>
parser.nextToken()

convertField(factory, parser, schema) match {
convertField(factory, parser, schema, configOptions) match {
case null => failedRecord(record)
case row: InternalRow => row :: Nil
case array: ArrayData =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {

test("allowNonNumericNumbers on") {
val testCases: Seq[String] = Seq("""{"age": NaN}""", """{"age": Infinity}""",
"""{"age": -Infinity}""")
val tests: Seq[Double => Boolean] = Seq(_.isNaN, _.isPosInfinity, _.isNegInfinity)
"""{"age": -Infinity}""", """{"age": "NaN"}""", """{"age": "Infinity"}""",
"""{"age": "-Infinity"}""")
val tests: Seq[Double => Boolean] = Seq(_.isNaN, _.isPosInfinity, _.isNegInfinity,
_.isNaN, _.isPosInfinity, _.isNegInfinity)

testCases.zipWithIndex.foreach { case (str, idx) =>
val rdd = sqlContext.sparkContext.parallelize(Seq(str))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {

Utils.tryWithResource(factory.createParser(writer.toString)) { parser =>
parser.nextToken()
JacksonParser.convertField(factory, parser, dataType)
JacksonParser.convertField(factory, parser, dataType, JSONOptions())
}
}

Expand Down

0 comments on commit 186fa5e

Please sign in to comment.