[SPARK-26163][SQL] Parsing decimals from JSON using locale
## What changes were proposed in this pull request?

In the PR, I propose using the locale option to parse (and infer) decimals from JSON input. After the changes, `JacksonParser` converts an input string to a `BigDecimal`, and then to Spark's `Decimal`, using `java.text.DecimalFormat`. The new behaviour can be switched off via the SQL config `spark.sql.legacy.decimalParsing.enabled`.
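
For illustration only (not part of this commit's diff), a minimal sketch of how the change is expected to behave end to end, assuming the existing `locale` JSON option and an active `SparkSession` named `spark`; the column name `d` and the sample value are made up:

```scala
import spark.implicits._  // assumes an active SparkSession named `spark`

// Under de-DE, "1000,001" uses ',' as the decimal separator; with this change the value
// can be read into a DECIMAL column when it is carried as a JSON string.
val df = spark.read
  .schema("d DECIMAL(10,5)")
  .option("locale", "de-DE")
  .json(Seq("""{"d": "1000,001"}""").toDS())

df.show()  // expected to print d = 1000.00100
```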

## How was this patch tested?

Added 2 tests to `JsonExpressionsSuite` for the `en-US`, `ko-KR`, `ru-RU` and `de-DE` locales:
- inferring the decimal type from JSON field values using the locale;
- converting JSON field values to the specified decimal type using the locales (example inputs are sketched below).
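
For reference, a rough sketch of how those locale-dependent inputs are produced, mirroring the `decimalInput` helper in the test diff below; `localizedJson` is a hypothetical name and the sample number is arbitrary:

```scala
import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.Locale

// Hypothetical helper: format 1000.001 the way the given locale writes it and
// wrap it in a JSON document, similar to the decimalInput() helper in the tests.
def localizedJson(langTag: String): String = {
  val symbols = new DecimalFormatSymbols(Locale.forLanguageTag(langTag))
  val format = new DecimalFormat("", symbols)
  s"""{"d": "${format.format(new java.math.BigDecimal("1000.001"))}"}"""
}

// de-DE and ru-RU write the decimal separator as ',', while en-US and ko-KR use '.'.
Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(tag => println(s"$tag -> ${localizedJson(tag)}"))
```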

Closes apache#23132 from MaxGekk/json-decimal-parsing-locale.

Lead-authored-by: Maxim Gekk <max.gekk@gmail.com>
Co-authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2 people authored and cloud-fan committed Nov 29, 2018
1 parent 8bfea86 commit 7a83d71
Showing 7 changed files with 132 additions and 52 deletions.
ExprUtils.scala
@@ -17,6 +17,9 @@

package org.apache.spark.sql.catalyst.expressions

import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
@@ -83,4 +86,22 @@ object ExprUtils {
}
}
}

def getDecimalParser(locale: Locale): String => java.math.BigDecimal = {
if (locale == Locale.US) { // Special handling the default locale for backward compatibility
(s: String) => new java.math.BigDecimal(s.replaceAll(",", ""))
} else {
val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale))
decimalFormat.setParseBigDecimal(true)
(s: String) => {
val pos = new ParsePosition(0)
val result = decimalFormat.parse(s, pos).asInstanceOf[java.math.BigDecimal]
if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) {
throw new IllegalArgumentException("Cannot parse any decimal");
} else {
result
}
}
}
}
}
jsonExpressions.scala
@@ -23,12 +23,10 @@ import scala.util.parsing.combinator.RegexParsers

import com.fasterxml.jackson.core._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -775,6 +773,9 @@ case class SchemaOfJson(
factory
}

@transient
private lazy val jsonInferSchema = new JsonInferSchema(jsonOptions)

@transient
private lazy val json = child.eval().asInstanceOf[UTF8String]

@@ -787,7 +788,7 @@
override def eval(v: InternalRow): Any = {
val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
parser.nextToken()
inferField(parser, jsonOptions)
jsonInferSchema.inferField(parser)
}

UTF8String.fromString(dt.catalogString)
JacksonParser.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -135,6 +136,8 @@ class JacksonParser(
}
}

private val decimalParser = ExprUtils.getDecimalParser(options.locale)

/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema.
@@ -261,6 +264,9 @@
(parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) {
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) =>
Decimal(parser.getDecimalValue, dt.precision, dt.scale)
case VALUE_STRING if parser.getTextLength >= 1 =>
val bigDecimal = decimalParser(parser.getText)
Decimal(bigDecimal, dt.precision, dt.scale)
}

case st: StructType =>
JsonInferSchema.scala
@@ -19,18 +19,23 @@ package org.apache.spark.sql.catalyst.json

import java.util.Comparator

import scala.util.control.Exception.allCatch

import com.fasterxml.jackson.core._

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

private[sql] object JsonInferSchema {
private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {

private val decimalParser = ExprUtils.getDecimalParser(options.locale)

/**
* Infer the type of a collection of json records in three stages:
@@ -40,21 +45,20 @@ private[sql] object JsonInferSchema {
*/
def infer[T](
json: RDD[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
val parseMode = options.parseMode
val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord

// In each RDD partition, perform schema inference on each row and merge afterwards.
val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
val typeMerger = JsonInferSchema.compatibleRootType(columnNameOfCorruptRecord, parseMode)
val mergedTypesFromPartitions = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
options.setJacksonOptions(factory)
iter.flatMap { row =>
try {
Utils.tryWithResource(createParser(factory, row)) { parser =>
parser.nextToken()
Some(inferField(parser, configOptions))
Some(inferField(parser))
}
} catch {
case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
@@ -82,42 +86,25 @@
}
json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult)

canonicalizeType(rootType, configOptions) match {
canonicalizeType(rootType, options) match {
case Some(st: StructType) => st
case _ =>
// canonicalizeType erases all empty structs, including the only one we want to keep
StructType(Nil)
}
}

private[this] val structFieldComparator = new Comparator[StructField] {
override def compare(o1: StructField, o2: StructField): Int = {
o1.name.compareTo(o2.name)
}
}

private def isSorted(arr: Array[StructField]): Boolean = {
var i: Int = 0
while (i < arr.length - 1) {
if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
return false
}
i += 1
}
true
}

/**
* Infer the type of a json document from the parser's token stream
*/
def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
def inferField(parser: JsonParser): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType

case FIELD_NAME =>
parser.nextToken()
inferField(parser, configOptions)
inferField(parser)

case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -128,18 +115,25 @@ private[sql] object JsonInferSchema {
// record fields' types have been combined.
NullType

case VALUE_STRING if options.prefersDecimal =>
val decimalTry = allCatch opt {
val bigDecimal = decimalParser(parser.getText)
DecimalType(bigDecimal.precision, bigDecimal.scale)
}
decimalTry.getOrElse(StringType)
case VALUE_STRING => StringType

case START_OBJECT =>
val builder = Array.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
inferField(parser, configOptions),
inferField(parser),
nullable = true)
}
val fields: Array[StructField] = builder.result()
// Note: other code relies on this sorting for correctness, so don't remove it!
java.util.Arrays.sort(fields, structFieldComparator)
java.util.Arrays.sort(fields, JsonInferSchema.structFieldComparator)
StructType(fields)

case START_ARRAY =>
@@ -148,15 +142,15 @@
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(
elementType, inferField(parser, configOptions))
elementType = JsonInferSchema.compatibleType(
elementType, inferField(parser))
}

ArrayType(elementType)

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if options.primitivesAsString => StringType

case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
case (VALUE_TRUE | VALUE_FALSE) if options.primitivesAsString => StringType

case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
@@ -172,7 +166,7 @@
} else {
DoubleType
}
case FLOAT | DOUBLE if configOptions.prefersDecimal =>
case FLOAT | DOUBLE if options.prefersDecimal =>
val v = parser.getDecimalValue
if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
DecimalType(Math.max(v.precision(), v.scale()), v.scale())
@@ -217,20 +211,39 @@

case other => Some(other)
}
}

object JsonInferSchema {
val structFieldComparator = new Comparator[StructField] {
override def compare(o1: StructField, o2: StructField): Int = {
o1.name.compareTo(o2.name)
}
}

def isSorted(arr: Array[StructField]): Boolean = {
var i: Int = 0
while (i < arr.length - 1) {
if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
return false
}
i += 1
}
true
}

private def withCorruptField(
def withCorruptField(
struct: StructType,
other: DataType,
columnNameOfCorruptRecords: String,
parseMode: ParseMode) = parseMode match {
parseMode: ParseMode): StructType = parseMode match {
case PermissiveMode =>
// If we see any other data type at the root level, we get records that cannot be
// parsed. So, we use the struct as the data type and add the corrupt field to the schema.
if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
// If this given struct does not have a column used for corrupt records,
// add this field.
val newFields: Array[StructField] =
StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
// Note: other code relies on this sorting for correctness, so don't remove it!
java.util.Arrays.sort(newFields, structFieldComparator)
StructType(newFields)
@@ -253,7 +266,7 @@
/**
* Remove top-level ArrayType wrappers and merge the remaining schemas
*/
private def compatibleRootType(
def compatibleRootType(
columnNameOfCorruptRecords: String,
parseMode: ParseMode): (DataType, DataType) => DataType = {
// Since we support array of json objects at the top level,
JsonExpressionsSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.text.SimpleDateFormat
import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
import java.util.{Calendar, Locale}

import org.scalatest.exceptions.TestFailedException
@@ -765,4 +765,44 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
timeZoneId = gmtId),
expectedErrMsg = "The field for corrupt records must be string type and nullable")
}

def decimalInput(langTag: String): (Decimal, String) = {
val decimalVal = new java.math.BigDecimal("1000.001")
val decimalType = new DecimalType(10, 5)
val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale)
val decimalFormat = new DecimalFormat("",
new DecimalFormatSymbols(Locale.forLanguageTag(langTag)))
val input = s"""{"d": "${decimalFormat.format(expected.toBigDecimal)}"}"""

(expected, input)
}

test("parse decimals using locale") {
def checkDecimalParsing(langTag: String): Unit = {
val schema = new StructType().add("d", DecimalType(10, 5))
val options = Map("locale" -> langTag)
val (expected, input) = decimalInput(langTag)

checkEvaluation(
JsonToStructs(schema, options, Literal.create(input), gmtId),
InternalRow(expected))
}

Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing)
}

test("inferring the decimal type using locale") {
def checkDecimalInfer(langTag: String, expectedType: String): Unit = {
val options = Map("locale" -> langTag, "prefersDecimal" -> "true")
val (_, input) = decimalInput(langTag)

checkEvaluation(
SchemaOfJson(Literal.create(input), options),
expectedType)
}

Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach {
checkDecimalInfer(_, """struct<d:decimal(7,3)>""")
}
}
}
JsonDataSource.scala
@@ -107,7 +107,7 @@ object TextInputJsonDataSource extends JsonDataSource {
}.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))

SQLExecution.withSQLConfPropagated(json.sparkSession) {
JsonInferSchema.infer(rdd, parsedOptions, rowParser)
new JsonInferSchema(parsedOptions).infer(rdd, rowParser)
}
}

@@ -166,7 +166,7 @@ object MultiLineJsonDataSource extends JsonDataSource {
.getOrElse(createParser(_: JsonFactory, _: PortableDataStream))

SQLExecution.withSQLConfPropagated(sparkSession) {
JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
new JsonInferSchema(parsedOptions).infer[PortableDataStream](sampled, parser)
}
}

(One more changed file is not shown here.)
