Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-22000][SQL] Address missing Upcast in JavaTypeInference.deserializerFor #23854

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._

object DeserializerBuildHelper {
/** Returns the current path with a sub-field extracted. */
/**
 * Returns the current path with the named sub-field extracted, upcast to the
 * data type the encoder expects for that field.
 */
def addToPath(
    path: Expression,
    part: String,
    dataType: DataType,
    walkedTypePath: Seq[String]): Expression =
  upCastToExpectedType(
    UnresolvedExtractValue(path, expressions.Literal(part)),
    dataType,
    walkedTypePath)

/** Returns the current path with a field at ordinal extracted. */
/**
 * Returns the current path with the field at `ordinal` extracted, upcast to the
 * data type the encoder expects for that field.
 */
def addToPathOrdinal(
    path: Expression,
    ordinal: Int,
    dataType: DataType,
    walkedTypePath: Seq[String]): Expression = {
  // Same as addToPath, but extracts by position instead of by name.
  val newPath = GetStructField(path, ordinal)
  upCastToExpectedType(newPath, dataType, walkedTypePath)
}

/**
 * Builds a deserializer expression via `funcForCreatingNewExpr` and wraps the result
 * in a null check when `nullable` is false.
 *
 * NOTE(review): `dataType` is not read here; it appears to be kept for signature
 * symmetry with `deserializerForWithNullSafetyAndUpcast`.
 */
def deserializerForWithNullSafety(
    expr: Expression,
    dataType: DataType,
    nullable: Boolean,
    walkedTypePath: Seq[String],
    funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression =
  expressionWithNullSafety(
    funcForCreatingNewExpr(expr, walkedTypePath),
    nullable,
    walkedTypePath)

/**
 * Same as `deserializerForWithNullSafety`, but first upcasts `expr` to the expected
 * `dataType` before building the deserializer.
 */
def deserializerForWithNullSafetyAndUpcast(
    expr: Expression,
    dataType: DataType,
    nullable: Boolean,
    walkedTypePath: Seq[String],
    funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
  val upcasted = upCastToExpectedType(expr, dataType, walkedTypePath)
  deserializerForWithNullSafety(
    upcasted, dataType, nullable, walkedTypePath, funcForCreatingNewExpr)
}

/** Wraps `expr` in `AssertNotNull` when the target does not permit null values. */
private def expressionWithNullSafety(
    expr: Expression,
    nullable: Boolean,
    walkedTypePath: Seq[String]): Expression =
  if (nullable) expr else AssertNotNull(expr, walkedTypePath)

/**
 * Builds a deserializer for classes exposing a static `valueOf` factory
 * (used in this file for boxed primitives and Java enums).
 */
def createDeserializerForTypesSupportValueOf(
    path: Expression,
    clazz: Class[_]): Expression =
  StaticInvoke(
    clazz,
    ObjectType(clazz),
    "valueOf",
    path :: Nil,
    returnNullable = false)

/** Builds a deserializer that converts the internal string value via `toString`. */
def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression =
  Invoke(
    path,
    "toString",
    ObjectType(classOf[java.lang.String]),
    returnNullable = returnNullable)

/** Builds a deserializer producing `java.sql.Date` via `DateTimeUtils.toJavaDate`. */
def createDeserializerForSqlDate(path: Expression): Expression =
  StaticInvoke(
    DateTimeUtils.getClass,
    ObjectType(classOf[java.sql.Date]),
    "toJavaDate",
    path :: Nil,
    returnNullable = false)

/** Builds a deserializer producing `java.sql.Timestamp` via `DateTimeUtils.toJavaTimestamp`. */
def createDeserializerForSqlTimestamp(path: Expression): Expression =
  StaticInvoke(
    DateTimeUtils.getClass,
    ObjectType(classOf[java.sql.Timestamp]),
    "toJavaTimestamp",
    path :: Nil,
    returnNullable = false)

/** Builds a deserializer producing `java.math.BigDecimal` via `toJavaBigDecimal`. */
def createDeserializerForJavaBigDecimal(
    path: Expression,
    returnNullable: Boolean): Expression =
  Invoke(
    path,
    "toJavaBigDecimal",
    ObjectType(classOf[java.math.BigDecimal]),
    returnNullable = returnNullable)

/** Builds a deserializer producing `scala.math.BigDecimal` via `toBigDecimal`. */
def createDeserializerForScalaBigDecimal(
    path: Expression,
    returnNullable: Boolean): Expression =
  Invoke(
    path,
    "toBigDecimal",
    ObjectType(classOf[BigDecimal]),
    returnNullable = returnNullable)

/** Builds a deserializer producing `java.math.BigInteger` via `toJavaBigInteger`. */
def createDeserializerForJavaBigInteger(
    path: Expression,
    returnNullable: Boolean): Expression =
  Invoke(
    path,
    "toJavaBigInteger",
    ObjectType(classOf[java.math.BigInteger]),
    returnNullable = returnNullable)

/** Builds a deserializer producing `scala.math.BigInt` via `toScalaBigInt`. */
def createDeserializerForScalaBigInt(path: Expression): Expression =
  Invoke(
    path,
    "toScalaBigInt",
    ObjectType(classOf[scala.math.BigInt]),
    returnNullable = false)

/**
 * Wraps `expr` in an `UpCast` so that the data type the encoder expects is remembered.
 *
 * While building the `deserializer` for an encoder we set up a lot of "unresolved"
 * expressions and would otherwise lose the required data type, which may lead to a
 * runtime error if the real type doesn't match the encoder's schema. For example, if
 * we build an encoder for `case class Data(a: Int, b: String)` and the real type is
 * [a: int, b: long], we would hit a runtime error saying we can't construct class
 * `Data` from an int and a long, because we lost the information that `b` should be
 * a string.
 *
 * Only leaf nodes need this treatment, so complex types pass through unchanged.
 */
private def upCastToExpectedType(
    expr: Expression,
    expected: DataType,
    walkedTypePath: Seq[String]): Expression = expected match {
  // Complex types are handled recursively at their leaves; no upcast needed here.
  case _: StructType | _: ArrayType | _: MapType => expr
  case _ => UpCast(expr, expected, walkedTypePath)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import scala.language.existentials

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
Expand Down Expand Up @@ -194,14 +195,20 @@ object JavaTypeInference {
*/
def deserializerFor(beanClass: Class[_]): Expression = {
  val typeToken = TypeToken.of(beanClass)
  // Seed the walked-type path with the root bean class for error reporting.
  val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
  val (dataType, nullable) = inferDataType(typeToken)

  // Assumes we are deserializing the first column of a row; upcast it to the
  // inferred type and apply null safety, mirroring ScalaReflection.
  deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
    nullable = nullable, walkedTypePath, (casted, walkedTypePath) => {
      deserializerFor(typeToken, casted, walkedTypePath)
    })
}

private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = UnresolvedExtractValue(path,
expressions.Literal(part))

private def deserializerFor(
typeToken: TypeToken[_],
path: Expression,
walkedTypePath: Seq[String]): Expression = {
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path

Expand All @@ -212,82 +219,87 @@ object JavaTypeInference {
c == classOf[java.lang.Float] ||
c == classOf[java.lang.Byte] ||
c == classOf[java.lang.Boolean] =>
StaticInvoke(
c,
ObjectType(c),
"valueOf",
path :: Nil,
returnNullable = false)
createDeserializerForTypesSupportValueOf(path, c)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(c),
"toJavaDate",
path :: Nil,
returnNullable = false)
createDeserializerForSqlDate(path)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(c),
"toJavaTimestamp",
path :: Nil,
returnNullable = false)
createDeserializerForSqlTimestamp(path)

case c if c == classOf[java.lang.String] =>
Invoke(path, "toString", ObjectType(classOf[String]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ScalaReflection does the same thing, do we have a problem there too?

AFAIK the path should be a string type column, and it's always safe to call UTF8String.toString. My gut feeling is that we are missing an Upcast somewhere in JavaTypeInference.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK the path should be a string type column

The sample code in the JIRA issue tried to bind an IntegerType column to a String field in a Java bean, which appears to break that expectation. (I guess ScalaReflection would not encounter this case.)

Spark doesn't throw an error for this case, though — instead it exhibits undefined behavior: compilation failures in the generated code, and possibly even runtime exceptions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sample code tried to bind IntegerType column to String field in Java bean

In scala, we can also do this and Spark will add Upcast. e.g. spark.range(1).as[String].collect works fine.

I did a quick search and JavaTypeInference has no Upcast. We should fix it and follow ScalaReflection

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK. I'll check and address it. Maybe it would be a separate PR if it doesn't fix the new test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, your suggestion seems to work nicely! I left a comment asking which approach to choose: please compare both approaches and comment. Thanks!

createDeserializerForString(path, returnNullable = true)

case c if c == classOf[java.math.BigDecimal] =>
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
createDeserializerForJavaBigDecimal(path, returnNullable = true)

case c if c == classOf[java.math.BigInteger] =>
createDeserializerForJavaBigInteger(path, returnNullable = true)

case c if c.isArray =>
val elementType = c.getComponentType
val primitiveMethod = elementType match {
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
case c if c == java.lang.Byte.TYPE => Some("toByteArray")
case c if c == java.lang.Short.TYPE => Some("toShortArray")
case c if c == java.lang.Integer.TYPE => Some("toIntArray")
case c if c == java.lang.Long.TYPE => Some("toLongArray")
case c if c == java.lang.Float.TYPE => Some("toFloatArray")
case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
case _ => None
val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
walkedTypePath
val (dataType, elementNullable) = inferDataType(elementType)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
deserializerForWithNullSafetyAndUpcast(
element,
dataType,
nullable = elementNullable,
newTypePath,
(casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
}

primitiveMethod.map { method =>
Invoke(path, method, ObjectType(c))
}.getOrElse {
Invoke(
MapObjects(
p => deserializerFor(typeToken.getComponentType, p),
path,
inferDataType(elementType)._1),
"array",
ObjectType(c))
val arrayData = UnresolvedMapObjects(mapFunction, path)

val methodName = elementType match {
case c if c == java.lang.Integer.TYPE => "toIntArray"
case c if c == java.lang.Long.TYPE => "toLongArray"
case c if c == java.lang.Double.TYPE => "toDoubleArray"
case c if c == java.lang.Float.TYPE => "toFloatArray"
case c if c == java.lang.Short.TYPE => "toShortArray"
case c if c == java.lang.Byte.TYPE => "toByteArray"
case c if c == java.lang.Boolean.TYPE => "toBooleanArray"
// non-primitive
case _ => "array"
}
Invoke(arrayData, methodName, ObjectType(c))

case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
UnresolvedMapObjects(
p => deserializerFor(et, p),
path,
customCollectionCls = Some(c))
val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
walkedTypePath
val (dataType, elementNullable) = inferDataType(et)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
deserializerForWithNullSafetyAndUpcast(
element,
dataType,
nullable = elementNullable,
newTypePath,
(casted, typePath) => deserializerFor(et, casted, typePath))
}

UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))

case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath

val keyData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(keyType, p),
p => deserializerFor(keyType, p, newTypePath),
MapKeys(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(valueType, p),
p => deserializerFor(valueType, p, newTypePath),
MapValues(path)),
"array",
ObjectType(classOf[Array[Any]]))
Expand All @@ -300,25 +312,25 @@ object JavaTypeInference {
returnNullable = false)

case other if other.isEnum =>
StaticInvoke(
other,
ObjectType(other),
"valueOf",
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)
createDeserializerForTypesSupportValueOf(
createDeserializerForString(path, returnNullable = false),
other)

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (_, nullable) = inferDataType(fieldType)
val constructor = deserializerFor(fieldType, addToPath(fieldName))
val setter = if (nullable) {
constructor
} else {
AssertNotNull(constructor, Seq("currently no type path record in java"))
}
val (dataType, nullable) = inferDataType(fieldType)
val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
s""", name: "$fieldName")""") +: walkedTypePath
val setter = deserializerForWithNullSafety(
path,
dataType,
nullable = nullable,
newTypePath,
(expr, typePath) => deserializerFor(fieldType,
addToPath(expr, fieldName, dataType, typePath), typePath))
p.getWriteMethod.getName -> setter
}.toMap

Expand Down
Loading