[SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF
## What changes were proposed in this pull request?

Currently, when we infer the schema for Scala/Java decimals, we return the `SYSTEM_DEFAULT` implementation as the data type, i.e. a decimal type with precision 38 and scale 18. But this is not correct: we know nothing about the actual precision and scale, and these defaults may not be enough to store the data. The problem arises in particular with UDFs, where we cast every input of type `DecimalType` to `DecimalType(38, 18)`: when that is not enough, null is passed as input to the UDF.
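
As an illustration of the failure mode, here is a minimal sketch (assuming a running `SparkSession` named `spark`; the column name and values are made up for the example). A 22-digit integral value fits in `DecimalType(30, 0)` but not in `DecimalType(38, 18)`, which leaves room for only 20 integral digits, so the cast inserted before the UDF overflows to null:

```scala
import java.math.BigDecimal

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._

// A 22-digit value with scale 0: representable as DecimalType(30, 0),
// but not as the inferred default DecimalType(38, 18).
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
  StructType(Seq(StructField("col1", DecimalType(30, 0)))))

val toStr = udf((v: BigDecimal) => if (v == null) "got null" else v.toPlainString)

// Before this patch: col1 is coerced to DecimalType(38, 18), the cast
// overflows to null, and the UDF receives null instead of the value.
df.select(toStr(df("col1"))).show(false)
```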

The PR defines custom handling for casting ScalaUDF inputs to the expected data types: the decimal precision and scale are picked from the input, so no cast to a different, and possibly wrong, precision and scale happens.
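
In terms of Catalyst types the rule is simply: when both the input type and the expected type are decimals, keep the input's own precision and scale. A hedged sketch of the idea (the real implementation is `udfInputToCastType` in the `TypeCoercion` diff below):

```scala
import org.apache.spark.sql.types._

// DecimalType.SYSTEM_DEFAULT is DecimalType(38, 18), the type inferred for
// Scala/Java BigDecimal parameters of a UDF.
val input: DataType    = DecimalType(30, 0)
val expected: DataType = DecimalType.SYSTEM_DEFAULT

// When both sides are decimals, the cast target is the input type itself,
// so no lossy cast is inserted; otherwise the expected type is used.
val castTarget = (input, expected) match {
  case (in: DecimalType, _: DecimalType) => in
  case (_, other) => other
}
assert(castTarget == DecimalType(30, 0))
```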

## How was this patch tested?

Added unit tests.

Closes apache#23308 from mgaido91/SPARK-26308.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
mgaido91 authored and cloud-fan committed Dec 20, 2018
1 parent 04d8e3a commit 98c0ca7
Showing 3 changed files with 63 additions and 2 deletions.
@@ -879,6 +879,37 @@ object TypeCoercion {
}
}
e.withNewChildren(children)

case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
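// Compute the cast target for each input with udfInputToCastType, so that
// decimal inputs keep their original precision and scale (SPARK-26308).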
val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
}
udf.withNewChildren(children)
}

private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = {
(input, expectedType) match {
// SPARK-26308: avoid casting to an arbitrary precision and scale for decimals. Please note
// that precision and scale cannot be inferred properly for a ScalaUDF because, when it is
// created, it is not bound to any column. So here the precision and scale of the input
// column is used.
case (in: DecimalType, _: DecimalType) => in
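// For complex types (arrays, maps, structs), recurse so that decimals nested
// inside them also keep the precision and scale of the input.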
case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) =>
ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp)
case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp, nullableExp)) =>
MapType(udfInputToCastType(keyDtIn, keyDtExp),
udfInputToCastType(valueDtIn, valueDtExp),
nullableExp)
case (StructType(fieldsIn), StructType(fieldsExp)) =>
val fieldTypes =
fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case (dtIn, dtExp) =>
udfInputToCastType(dtIn, dtExp)
}
StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) =>
field.copy(dataType = newDt)
})
case (_, other) => other
}
}

/**
@@ -52,7 +52,7 @@ case class ScalaUDF(
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
-  extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
+  extends Expression with NonSQLExpression with UserDefinedExpression {

// The constructor for SPARK 2.1 and 2.2
def this(
32 changes: 31 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql

import java.math.BigDecimal

import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.QueryExecution
@@ -26,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm
import org.apache.spark.sql.functions.{lit, udf}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
-import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.QueryExecutionListener


@@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
}
}

test("SPARK-26308: udf with decimal") {
val df1 = spark.createDataFrame(
sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
StructType(Seq(StructField("col1", DecimalType(30, 0)))))
val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => {
if (value == null) null else value.toBigInteger.toString
})
checkAnswer(df1.select(udf1(df1.col("col1"))), Seq(Row("2011000000000002456556")))
}

test("SPARK-26308: udf with complex types of decimal") {
val df1 = spark.createDataFrame(
sparkContext.parallelize(Seq(Row(Array(new BigDecimal("2011000000000002456556"))))),
StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0))))))
val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => {
arr.map(value => if (value == null) null else value.toBigInteger.toString)
})
checkAnswer(df1.select(udf1($"col1")), Seq(Row(Array("2011000000000002456556"))))

val df2 = spark.createDataFrame(
sparkContext.parallelize(Seq(Row(Map("a" -> new BigDecimal("2011000000000002456556"))))),
StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30, 0))))))
val udf2 = org.apache.spark.sql.functions.udf((map: Map[String, BigDecimal]) => {
map.mapValues(value => if (value == null) null else value.toBigInteger.toString)
})
checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556"))))
}
}
