-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-22000][SQL] Address missing Upcast in JavaTypeInference.deserializerFor #23854
Changes from all commits
acd5d55
443e74f
d869bba
60ca088
9eeb391
d4c2060
9c040e2
2007452
4d564ab
e01bfe6
8cbad26
7ff01db
24a1b19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
||
package org.apache.spark.sql.catalyst | ||
|
||
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue | ||
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast} | ||
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke} | ||
import org.apache.spark.sql.catalyst.util.DateTimeUtils | ||
import org.apache.spark.sql.types._ | ||
|
||
object DeserializerBuildHelper {

  /**
   * Extracts the sub-field named `part` from `path` and upcasts the result to the
   * data type the encoder expects.
   */
  def addToPath(
      path: Expression,
      part: String,
      dataType: DataType,
      walkedTypePath: Seq[String]): Expression = {
    val extracted = UnresolvedExtractValue(path, expressions.Literal(part))
    upCastToExpectedType(extracted, dataType, walkedTypePath)
  }

  /**
   * Extracts the struct field at `ordinal` from `path` and upcasts the result to the
   * data type the encoder expects.
   */
  def addToPathOrdinal(
      path: Expression,
      ordinal: Int,
      dataType: DataType,
      walkedTypePath: Seq[String]): Expression = {
    val extracted = GetStructField(path, ordinal)
    upCastToExpectedType(extracted, dataType, walkedTypePath)
  }

  /**
   * Builds a deserializer expression via `funcForCreatingNewExpr` and wraps it with a
   * null check (`AssertNotNull`) when the target field is not nullable.
   */
  def deserializerForWithNullSafety(
      expr: Expression,
      dataType: DataType,
      nullable: Boolean,
      walkedTypePath: Seq[String],
      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
    val built = funcForCreatingNewExpr(expr, walkedTypePath)
    expressionWithNullSafety(built, nullable, walkedTypePath)
  }

  /**
   * Same as [[deserializerForWithNullSafety]], but first upcasts `expr` to `dataType`
   * so schema mismatches are caught at analysis time rather than at runtime.
   */
  def deserializerForWithNullSafetyAndUpcast(
      expr: Expression,
      dataType: DataType,
      nullable: Boolean,
      walkedTypePath: Seq[String],
      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
    val upcasted = upCastToExpectedType(expr, dataType, walkedTypePath)
    deserializerForWithNullSafety(upcasted, dataType, nullable, walkedTypePath,
      funcForCreatingNewExpr)
  }

  // Guards a non-nullable field with AssertNotNull; nullable fields pass through untouched.
  private def expressionWithNullSafety(
      expr: Expression,
      nullable: Boolean,
      walkedTypePath: Seq[String]): Expression =
    if (nullable) expr else AssertNotNull(expr, walkedTypePath)

  /** Deserializes via the target class's static `valueOf` factory (boxed primitives, enums). */
  def createDeserializerForTypesSupportValueOf(
      path: Expression,
      clazz: Class[_]): Expression = {
    StaticInvoke(
      clazz,
      ObjectType(clazz),
      "valueOf",
      path :: Nil,
      returnNullable = false)
  }

  /** Deserializes a string column into a `java.lang.String` via `toString`. */
  def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = {
    Invoke(path, "toString", ObjectType(classOf[java.lang.String]),
      returnNullable = returnNullable)
  }

  /** Deserializes a date column into a `java.sql.Date` through `DateTimeUtils`. */
  def createDeserializerForSqlDate(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.sql.Date]),
      "toJavaDate",
      path :: Nil,
      returnNullable = false)
  }

  /** Deserializes a timestamp column into a `java.sql.Timestamp` through `DateTimeUtils`. */
  def createDeserializerForSqlTimestamp(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.sql.Timestamp]),
      "toJavaTimestamp",
      path :: Nil,
      returnNullable = false)
  }

  /** Deserializes a decimal column into a `java.math.BigDecimal`. */
  def createDeserializerForJavaBigDecimal(
      path: Expression,
      returnNullable: Boolean): Expression = {
    Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
      returnNullable = returnNullable)
  }

  /** Deserializes a decimal column into a `scala.math.BigDecimal`. */
  def createDeserializerForScalaBigDecimal(
      path: Expression,
      returnNullable: Boolean): Expression = {
    Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = returnNullable)
  }

  /** Deserializes a decimal column into a `java.math.BigInteger`. */
  def createDeserializerForJavaBigInteger(
      path: Expression,
      returnNullable: Boolean): Expression = {
    Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
      returnNullable = returnNullable)
  }

  /** Deserializes a decimal column into a `scala.math.BigInt`. */
  def createDeserializerForScalaBigInt(path: Expression): Expression = {
    Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
      returnNullable = false)
  }

  /**
   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
   * and lost the required data type, which may lead to runtime error if the real type doesn't
   * match the encoder's schema.
   * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
   * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
   * `Data` with int and long, because we lost the information that `b` should be a string.
   *
   * This method help us "remember" the required data type by adding a `UpCast`. Note that we
   * only need to do this for leaf nodes.
   */
  private def upCastToExpectedType(
      expr: Expression,
      expected: DataType,
      walkedTypePath: Seq[String]): Expression = expected match {
    // Complex types are handled recursively by their children; only leaf nodes need the UpCast.
    case _: StructType | _: ArrayType | _: MapType => expr
    case _ => UpCast(expr, expected, walkedTypePath)
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,8 @@ import scala.language.existentials | |
|
||
import com.google.common.reflect.TypeToken | ||
|
||
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} | ||
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ | ||
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal | ||
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.objects._ | ||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} | ||
|
@@ -194,14 +195,20 @@ object JavaTypeInference { | |
*/ | ||
def deserializerFor(beanClass: Class[_]): Expression = { | ||
val typeToken = TypeToken.of(beanClass) | ||
deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1)) | ||
val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil | ||
val (dataType, nullable) = inferDataType(typeToken) | ||
|
||
// Assumes we are deserializing the first column of a row. | ||
deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, | ||
nullable = nullable, walkedTypePath, (casted, walkedTypePath) => { | ||
deserializerFor(typeToken, casted, walkedTypePath) | ||
}) | ||
} | ||
|
||
private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = { | ||
/** Returns the current path with a sub-field extracted. */ | ||
def addToPath(part: String): Expression = UnresolvedExtractValue(path, | ||
expressions.Literal(part)) | ||
|
||
private def deserializerFor( | ||
typeToken: TypeToken[_], | ||
path: Expression, | ||
walkedTypePath: Seq[String]): Expression = { | ||
typeToken.getRawType match { | ||
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path | ||
|
||
|
@@ -212,82 +219,87 @@ object JavaTypeInference { | |
c == classOf[java.lang.Float] || | ||
c == classOf[java.lang.Byte] || | ||
c == classOf[java.lang.Boolean] => | ||
StaticInvoke( | ||
c, | ||
ObjectType(c), | ||
"valueOf", | ||
path :: Nil, | ||
returnNullable = false) | ||
createDeserializerForTypesSupportValueOf(path, c) | ||
|
||
case c if c == classOf[java.sql.Date] => | ||
StaticInvoke( | ||
DateTimeUtils.getClass, | ||
ObjectType(c), | ||
"toJavaDate", | ||
path :: Nil, | ||
returnNullable = false) | ||
createDeserializerForSqlDate(path) | ||
|
||
case c if c == classOf[java.sql.Timestamp] => | ||
StaticInvoke( | ||
DateTimeUtils.getClass, | ||
ObjectType(c), | ||
"toJavaTimestamp", | ||
path :: Nil, | ||
returnNullable = false) | ||
createDeserializerForSqlTimestamp(path) | ||
|
||
case c if c == classOf[java.lang.String] => | ||
Invoke(path, "toString", ObjectType(classOf[String])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The AFAIK the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The sample code in JIRA issue tried to bind IntegerType column to String field in Java bean, which looks to break your expectation. (I guess ScalaReflection would not encounter this case.) Spark doesn't throw error for this case though - actually Spark would show undefined behaviors, compilation failures on generated code, even might be possible to throw runtime exception. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In scala, we can also do this and Spark will add I did a quick search and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah OK. I'll check and address it. Maybe it would be a separate PR if it doesn't fix the new test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah your suggestion seems to work nicely! I left comment to ask which approach to choose: please compare both approach and comment. Thanks! |
||
createDeserializerForString(path, returnNullable = true) | ||
|
||
case c if c == classOf[java.math.BigDecimal] => | ||
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) | ||
createDeserializerForJavaBigDecimal(path, returnNullable = true) | ||
|
||
case c if c == classOf[java.math.BigInteger] => | ||
createDeserializerForJavaBigInteger(path, returnNullable = true) | ||
|
||
case c if c.isArray => | ||
val elementType = c.getComponentType | ||
val primitiveMethod = elementType match { | ||
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") | ||
case c if c == java.lang.Byte.TYPE => Some("toByteArray") | ||
case c if c == java.lang.Short.TYPE => Some("toShortArray") | ||
case c if c == java.lang.Integer.TYPE => Some("toIntArray") | ||
case c if c == java.lang.Long.TYPE => Some("toLongArray") | ||
case c if c == java.lang.Float.TYPE => Some("toFloatArray") | ||
case c if c == java.lang.Double.TYPE => Some("toDoubleArray") | ||
case _ => None | ||
val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +: | ||
walkedTypePath | ||
val (dataType, elementNullable) = inferDataType(elementType) | ||
val mapFunction: Expression => Expression = element => { | ||
// upcast the array element to the data type the encoder expected. | ||
deserializerForWithNullSafetyAndUpcast( | ||
element, | ||
dataType, | ||
nullable = elementNullable, | ||
newTypePath, | ||
(casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) | ||
} | ||
|
||
primitiveMethod.map { method => | ||
Invoke(path, method, ObjectType(c)) | ||
}.getOrElse { | ||
Invoke( | ||
MapObjects( | ||
p => deserializerFor(typeToken.getComponentType, p), | ||
path, | ||
inferDataType(elementType)._1), | ||
"array", | ||
ObjectType(c)) | ||
val arrayData = UnresolvedMapObjects(mapFunction, path) | ||
|
||
val methodName = elementType match { | ||
case c if c == java.lang.Integer.TYPE => "toIntArray" | ||
case c if c == java.lang.Long.TYPE => "toLongArray" | ||
case c if c == java.lang.Double.TYPE => "toDoubleArray" | ||
case c if c == java.lang.Float.TYPE => "toFloatArray" | ||
case c if c == java.lang.Short.TYPE => "toShortArray" | ||
case c if c == java.lang.Byte.TYPE => "toByteArray" | ||
case c if c == java.lang.Boolean.TYPE => "toBooleanArray" | ||
// non-primitive | ||
case _ => "array" | ||
} | ||
Invoke(arrayData, methodName, ObjectType(c)) | ||
|
||
case c if listType.isAssignableFrom(typeToken) => | ||
val et = elementType(typeToken) | ||
UnresolvedMapObjects( | ||
p => deserializerFor(et, p), | ||
path, | ||
customCollectionCls = Some(c)) | ||
val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +: | ||
walkedTypePath | ||
val (dataType, elementNullable) = inferDataType(et) | ||
val mapFunction: Expression => Expression = element => { | ||
// upcast the array element to the data type the encoder expected. | ||
deserializerForWithNullSafetyAndUpcast( | ||
element, | ||
dataType, | ||
nullable = elementNullable, | ||
newTypePath, | ||
(casted, typePath) => deserializerFor(et, casted, typePath)) | ||
} | ||
|
||
UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c)) | ||
|
||
case _ if mapType.isAssignableFrom(typeToken) => | ||
val (keyType, valueType) = mapKeyValueType(typeToken) | ||
val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" + | ||
s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath | ||
|
||
val keyData = | ||
Invoke( | ||
UnresolvedMapObjects( | ||
p => deserializerFor(keyType, p), | ||
p => deserializerFor(keyType, p, newTypePath), | ||
MapKeys(path)), | ||
"array", | ||
ObjectType(classOf[Array[Any]])) | ||
|
||
val valueData = | ||
Invoke( | ||
UnresolvedMapObjects( | ||
p => deserializerFor(valueType, p), | ||
p => deserializerFor(valueType, p, newTypePath), | ||
MapValues(path)), | ||
"array", | ||
ObjectType(classOf[Array[Any]])) | ||
|
@@ -300,25 +312,25 @@ object JavaTypeInference { | |
returnNullable = false) | ||
|
||
case other if other.isEnum => | ||
StaticInvoke( | ||
other, | ||
ObjectType(other), | ||
"valueOf", | ||
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, | ||
returnNullable = false) | ||
createDeserializerForTypesSupportValueOf( | ||
createDeserializerForString(path, returnNullable = false), | ||
other) | ||
|
||
case other => | ||
val properties = getJavaBeanReadableAndWritableProperties(other) | ||
val setters = properties.map { p => | ||
val fieldName = p.getName | ||
val fieldType = typeToken.method(p.getReadMethod).getReturnType | ||
val (_, nullable) = inferDataType(fieldType) | ||
val constructor = deserializerFor(fieldType, addToPath(fieldName)) | ||
val setter = if (nullable) { | ||
constructor | ||
} else { | ||
AssertNotNull(constructor, Seq("currently no type path record in java")) | ||
} | ||
val (dataType, nullable) = inferDataType(fieldType) | ||
val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" + | ||
s""", name: "$fieldName")""") +: walkedTypePath | ||
val setter = deserializerForWithNullSafety( | ||
path, | ||
dataType, | ||
nullable = nullable, | ||
newTypePath, | ||
(expr, typePath) => deserializerFor(fieldType, | ||
addToPath(expr, fieldName, dataType, typePath), typePath)) | ||
p.getWriteMethod.getName -> setter | ||
}.toMap | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is only used in
ScalaReflection
? If so, how about moving this into there and make it private?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be weird if we split out similar code (basically this only adds index handling on
addToPath
) to multiple places, but no strong opinion. WDYT? If you still feel better to move this to ScalaReflection, please let me know so that I can move it.