From acd5d55404439a582c161f54b3f7b144da31147e Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 21 Feb 2019 18:05:54 +0900 Subject: [PATCH 01/13] [SPARK-22000][SQL] Use String.valueOf to assign value to String type of field in Java Bean --- .../sql/catalyst/JavaTypeInference.scala | 6 +- .../sql/JavaBeanDeserializationSuite.java | 100 ++++++++++++++++++ .../test/resources/test-data/spark-22000.csv | 5 + 3 files changed, 107 insertions(+), 4 deletions(-) create mode 100755 sql/core/src/test/resources/test-data/spark-22000.csv diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 311060e5961cb..b28570e3fff91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -211,7 +211,8 @@ object JavaTypeInference { c == classOf[java.lang.Double] || c == classOf[java.lang.Float] || c == classOf[java.lang.Byte] || - c == classOf[java.lang.Boolean] => + c == classOf[java.lang.Boolean] || + c == classOf[java.lang.String] => StaticInvoke( c, ObjectType(c), @@ -235,9 +236,6 @@ object JavaTypeInference { path :: Nil, returnNullable = false) - case c if c == classOf[java.lang.String] => - Invoke(path, "toString", ObjectType(classOf[String])) - case c if c == classOf[java.math.BigDecimal] => Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 8f35abeb579b5..ae68c7aa63613 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -115,6 +115,37 @@ public void testBeanWithMapFieldsDeserialization() { Assert.assertEquals(records, MAP_RECORDS); } + private static final List RECORDS_SPARK_22000 = new ArrayList<>(); + + static { + RECORDS_SPARK_22000.add(new RecordSpark22000("1", "j123@aaa.com", 2, 11)); + RECORDS_SPARK_22000.add(new RecordSpark22000("2", "j123@aaa.com", 3, 12)); + RECORDS_SPARK_22000.add(new RecordSpark22000("3", "j123@aaa.com", 4, 13)); + RECORDS_SPARK_22000.add(new RecordSpark22000("4", "j123@aaa.com", 5, 14)); + } + + @Test + public void testSpark22000() { + // Here we try to convert the type of 'ref' field, from integer to string. + // Before applying SPARK-22000, Spark called toString() against variable which type might be primitive. + // SPARK-22000 it calls String.valueOf() which finally calls toString() but handles boxing + // if the type is primitive. 
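+    // A rough illustration of the difference (not the exact generated code): for a primitive
+    // column value such as `int v`, the expression `String.valueOf(v)` compiles because
+    // String.valueOf accepts primitives directly, whereas `v.toString()` is a compile error
+    // ("int cannot be dereferenced") unless the value is boxed first.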
+ Encoder encoder = Encoders.bean(RecordSpark22000.class); + + Dataset dataset = spark + .read() + .format("csv") + .option("header", "true") + .option("mode", "DROPMALFORMED") + .schema("ref int, userId string, x int, y int") + .load("src/test/resources/test-data/spark-22000.csv") + .as(encoder); + + List records = dataset.collectAsList(); + + Assert.assertEquals(records, RECORDS_SPARK_22000); + } + public static class ArrayRecord { private int id; @@ -252,4 +283,73 @@ public String toString() { return String.format("[%d,%d]", startTime, endTime); } } + + public static class RecordSpark22000 { + private String ref; + private String userId; + private int x; + private int y; + + public RecordSpark22000() { } + + RecordSpark22000(String ref, String userId, int x, int y) { + this.ref = ref; + this.userId = userId; + this.x = x; + this.y = y; + } + + public String getRef() { + return ref; + } + + public void setRef(String ref) { + this.ref = ref; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public int getX() { + return x; + } + + public void setX(int x) { + this.x = x; + } + + public int getY() { + return y; + } + + public void setY(int y) { + this.y = y; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecordSpark22000 that = (RecordSpark22000) o; + return x == that.x && + y == that.y && + Objects.equals(ref, that.ref) && + Objects.equals(userId, that.userId); + } + + @Override + public int hashCode() { + return Objects.hash(ref, userId, x, y); + } + + @Override + public String toString() { + return String.format("ref='%s', userId='%s', x=%d, y=%d", ref, userId, x, y); + } + } } diff --git a/sql/core/src/test/resources/test-data/spark-22000.csv b/sql/core/src/test/resources/test-data/spark-22000.csv new file mode 100755 index 0000000000000..06deb6f293352 --- /dev/null +++ b/sql/core/src/test/resources/test-data/spark-22000.csv @@ -0,0 +1,5 @@ +ref,userId,x,y +1,j123@aaa.com,2,11 +2,j123@aaa.com,3,12 +3,j123@aaa.com,4,13 +4,j123@aaa.com,5,14 \ No newline at end of file From 443e74f207d4869d11944f4d95df97f5bc40e6bc Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 21 Feb 2019 23:03:13 +0900 Subject: [PATCH 02/13] Address review comments from maropu --- .../sql/JavaBeanDeserializationSuite.java | 175 +++++++++++++----- .../test/resources/test-data/spark-22000.csv | 5 - 2 files changed, 130 insertions(+), 50 deletions(-) delete mode 100755 sql/core/src/test/resources/test-data/spark-22000.csv diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index ae68c7aa63613..09426e1cf0d73 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -20,6 +20,15 @@ import java.io.Serializable; import java.util.*; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverterSuite; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.DataTypes; +import 
org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.types.UTF8String; import org.junit.*; import org.apache.spark.sql.Dataset; @@ -115,35 +124,68 @@ public void testBeanWithMapFieldsDeserialization() { Assert.assertEquals(records, MAP_RECORDS); } + private static final List ROWS_SPARK_22000 = new ArrayList<>(); private static final List RECORDS_SPARK_22000 = new ArrayList<>(); + private static Row createRecordSpark22000Row(Long index) { + Object[] values = new Object[] { + index.shortValue(), + index.intValue(), + index, + index.floatValue(), + index.doubleValue(), + String.valueOf(index), + index % 2 == 0, + new java.sql.Timestamp(System.currentTimeMillis()) + }; + return new GenericRow(values); + } + + private static RecordSpark22000 createRecordSpark22000(Row recordRow) { + RecordSpark22000 record = new RecordSpark22000(); + record.setShortField(String.valueOf(recordRow.getShort(0))); + record.setIntField(String.valueOf(recordRow.getInt(1))); + record.setLongField(String.valueOf(recordRow.getLong(2))); + record.setFloatField(String.valueOf(recordRow.getFloat(3))); + record.setDoubleField(String.valueOf(recordRow.getDouble(4))); + record.setStringField(recordRow.getString(5)); + record.setBooleanField(String.valueOf(recordRow.getBoolean(6))); + record.setTimestampField(String.valueOf(recordRow.getTimestamp(7).getTime() * 1000)); + return record; + } + static { - RECORDS_SPARK_22000.add(new RecordSpark22000("1", "j123@aaa.com", 2, 11)); - RECORDS_SPARK_22000.add(new RecordSpark22000("2", "j123@aaa.com", 3, 12)); - RECORDS_SPARK_22000.add(new RecordSpark22000("3", "j123@aaa.com", 4, 13)); - RECORDS_SPARK_22000.add(new RecordSpark22000("4", "j123@aaa.com", 5, 14)); + for (long idx = 0 ; idx < 5 ; idx++) { + Row row = createRecordSpark22000Row(idx); + ROWS_SPARK_22000.add(row); + RECORDS_SPARK_22000.add(createRecordSpark22000(row)); + } } @Test public void testSpark22000() { - // Here we try to convert the type of 'ref' field, from integer to string. + // Here we try to convert the fields, from any types to string. // Before applying SPARK-22000, Spark called toString() against variable which type might be primitive. // SPARK-22000 it calls String.valueOf() which finally calls toString() but handles boxing // if the type is primitive. 
Encoder encoder = Encoders.bean(RecordSpark22000.class); - Dataset dataset = spark - .read() - .format("csv") - .option("header", "true") - .option("mode", "DROPMALFORMED") - .schema("ref int, userId string, x int, y int") - .load("src/test/resources/test-data/spark-22000.csv") - .as(encoder); + StructType schema = new StructType() + .add("shortField", DataTypes.ShortType) + .add("intField", DataTypes.IntegerType) + .add("longField", DataTypes.LongType) + .add("floatField", DataTypes.FloatType) + .add("doubleField", DataTypes.DoubleType) + .add("stringField", DataTypes.StringType) + .add("booleanField", DataTypes.BooleanType) + .add("timestampField", DataTypes.TimestampType); + + Dataset dataFrame = spark.createDataFrame(ROWS_SPARK_22000, schema); + Dataset dataset = dataFrame.as(encoder); List records = dataset.collectAsList(); - Assert.assertEquals(records, RECORDS_SPARK_22000); + Assert.assertEquals(RECORDS_SPARK_22000, records); } public static class ArrayRecord { @@ -285,50 +327,79 @@ public String toString() { } public static class RecordSpark22000 { - private String ref; - private String userId; - private int x; - private int y; + private String shortField; + private String intField; + private String longField; + private String floatField; + private String doubleField; + private String stringField; + private String booleanField; + private String timestampField; public RecordSpark22000() { } - RecordSpark22000(String ref, String userId, int x, int y) { - this.ref = ref; - this.userId = userId; - this.x = x; - this.y = y; + public String getShortField() { + return shortField; + } + + public void setShortField(String shortField) { + this.shortField = shortField; + } + + public String getIntField() { + return intField; + } + + public void setIntField(String intField) { + this.intField = intField; + } + + public String getLongField() { + return longField; + } + + public void setLongField(String longField) { + this.longField = longField; + } + + public String getFloatField() { + return floatField; + } + + public void setFloatField(String floatField) { + this.floatField = floatField; } - public String getRef() { - return ref; + public String getDoubleField() { + return doubleField; } - public void setRef(String ref) { - this.ref = ref; + public void setDoubleField(String doubleField) { + this.doubleField = doubleField; } - public String getUserId() { - return userId; + public String getStringField() { + return stringField; } - public void setUserId(String userId) { - this.userId = userId; + public void setStringField(String stringField) { + this.stringField = stringField; } - public int getX() { - return x; + public String getBooleanField() { + return booleanField; } - public void setX(int x) { - this.x = x; + public void setBooleanField(String booleanField) { + this.booleanField = booleanField; } - public int getY() { - return y; + public String getTimestampField() { + return timestampField; } - public void setY(int y) { - this.y = y; + public void setTimestampField(String timestampField) { + this.timestampField = timestampField; } @Override @@ -336,20 +407,34 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RecordSpark22000 that = (RecordSpark22000) o; - return x == that.x && - y == that.y && - Objects.equals(ref, that.ref) && - Objects.equals(userId, that.userId); + return Objects.equals(shortField, that.shortField) && + Objects.equals(intField, that.intField) && + Objects.equals(longField, that.longField) && + 
Objects.equals(floatField, that.floatField) && + Objects.equals(doubleField, that.doubleField) && + Objects.equals(stringField, that.stringField) && + Objects.equals(booleanField, that.booleanField) && + Objects.equals(timestampField, that.timestampField); } @Override public int hashCode() { - return Objects.hash(ref, userId, x, y); + return Objects.hash(shortField, intField, longField, floatField, doubleField, stringField, + booleanField, timestampField); } @Override public String toString() { - return String.format("ref='%s', userId='%s', x=%d, y=%d", ref, userId, x, y); + return com.google.common.base.Objects.toStringHelper(this) + .add("shortField", shortField) + .add("intField", intField) + .add("longField", longField) + .add("floatField", floatField) + .add("doubleField", doubleField) + .add("stringField", stringField) + .add("booleanField", booleanField) + .add("timestampField", timestampField) + .toString(); } } } diff --git a/sql/core/src/test/resources/test-data/spark-22000.csv b/sql/core/src/test/resources/test-data/spark-22000.csv deleted file mode 100755 index 06deb6f293352..0000000000000 --- a/sql/core/src/test/resources/test-data/spark-22000.csv +++ /dev/null @@ -1,5 +0,0 @@ -ref,userId,x,y -1,j123@aaa.com,2,11 -2,j123@aaa.com,3,12 -3,j123@aaa.com,4,13 -4,j123@aaa.com,5,14 \ No newline at end of file From d869bbafe9c499c09b994331f280d53cc52fceb9 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 22 Feb 2019 06:13:33 +0900 Subject: [PATCH 03/13] Address review comments from srowen, as well as fix style --- .../sql/JavaBeanDeserializationSuite.java | 104 ++++++++++-------- 1 file changed, 57 insertions(+), 47 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 09426e1cf0d73..1bc3e6d2155da 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -21,14 +21,9 @@ import java.util.*; import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverterSuite; -import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.types.UTF8String; import org.junit.*; import org.apache.spark.sql.Dataset; @@ -124,48 +119,20 @@ public void testBeanWithMapFieldsDeserialization() { Assert.assertEquals(records, MAP_RECORDS); } - private static final List ROWS_SPARK_22000 = new ArrayList<>(); - private static final List RECORDS_SPARK_22000 = new ArrayList<>(); - - private static Row createRecordSpark22000Row(Long index) { - Object[] values = new Object[] { - index.shortValue(), - index.intValue(), - index, - index.floatValue(), - index.doubleValue(), - String.valueOf(index), - index % 2 == 0, - new java.sql.Timestamp(System.currentTimeMillis()) - }; - return new GenericRow(values); - } - - private static RecordSpark22000 createRecordSpark22000(Row recordRow) { - RecordSpark22000 record = new RecordSpark22000(); - record.setShortField(String.valueOf(recordRow.getShort(0))); - record.setIntField(String.valueOf(recordRow.getInt(1))); - 
record.setLongField(String.valueOf(recordRow.getLong(2))); - record.setFloatField(String.valueOf(recordRow.getFloat(3))); - record.setDoubleField(String.valueOf(recordRow.getDouble(4))); - record.setStringField(recordRow.getString(5)); - record.setBooleanField(String.valueOf(recordRow.getBoolean(6))); - record.setTimestampField(String.valueOf(recordRow.getTimestamp(7).getTime() * 1000)); - return record; - } + @Test + public void testSpark22000() { + List inputRows = new ArrayList<>(); + List expectedRecords = new ArrayList<>(); - static { for (long idx = 0 ; idx < 5 ; idx++) { Row row = createRecordSpark22000Row(idx); - ROWS_SPARK_22000.add(row); - RECORDS_SPARK_22000.add(createRecordSpark22000(row)); + inputRows.add(row); + expectedRecords.add(createRecordSpark22000(row)); } - } - @Test - public void testSpark22000() { // Here we try to convert the fields, from any types to string. - // Before applying SPARK-22000, Spark called toString() against variable which type might be primitive. + // Before applying SPARK-22000, Spark called toString() against variable which type might + // be primitive. // SPARK-22000 it calls String.valueOf() which finally calls toString() but handles boxing // if the type is primitive. Encoder encoder = Encoders.bean(RecordSpark22000.class); @@ -178,14 +145,46 @@ public void testSpark22000() { .add("doubleField", DataTypes.DoubleType) .add("stringField", DataTypes.StringType) .add("booleanField", DataTypes.BooleanType) - .add("timestampField", DataTypes.TimestampType); + .add("timestampField", DataTypes.TimestampType) + // explicitly setting nullable = true to make clear the intention + .add("nullIntField", DataTypes.IntegerType, true); - Dataset dataFrame = spark.createDataFrame(ROWS_SPARK_22000, schema); + Dataset dataFrame = spark.createDataFrame(inputRows, schema); Dataset dataset = dataFrame.as(encoder); List records = dataset.collectAsList(); - Assert.assertEquals(RECORDS_SPARK_22000, records); + Assert.assertEquals(records, records); + } + + private static Row createRecordSpark22000Row(Long index) { + Object[] values = new Object[] { + index.shortValue(), + index.intValue(), + index, + index.floatValue(), + index.doubleValue(), + String.valueOf(index), + index % 2 == 0, + new java.sql.Timestamp(System.currentTimeMillis()), + null + }; + return new GenericRow(values); + } + + private static RecordSpark22000 createRecordSpark22000(Row recordRow) { + RecordSpark22000 record = new RecordSpark22000(); + record.setShortField(String.valueOf(recordRow.getShort(0))); + record.setIntField(String.valueOf(recordRow.getInt(1))); + record.setLongField(String.valueOf(recordRow.getLong(2))); + record.setFloatField(String.valueOf(recordRow.getFloat(3))); + record.setDoubleField(String.valueOf(recordRow.getDouble(4))); + record.setStringField(recordRow.getString(5)); + record.setBooleanField(String.valueOf(recordRow.getBoolean(6))); + record.setTimestampField(String.valueOf(recordRow.getTimestamp(7).getTime() * 1000)); + // This would figure out that null value will not become "null". 
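+    // (If the null were routed through String.valueOf(Object) it would come back as the literal
+    // string "null"; the expected record keeps a real null here instead.)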
+ record.setNullIntField(null); + return record; } public static class ArrayRecord { @@ -326,7 +325,7 @@ public String toString() { } } - public static class RecordSpark22000 { + public final static class RecordSpark22000 { private String shortField; private String intField; private String longField; @@ -335,6 +334,7 @@ public static class RecordSpark22000 { private String stringField; private String booleanField; private String timestampField; + private String nullIntField; public RecordSpark22000() { } @@ -402,6 +402,14 @@ public void setTimestampField(String timestampField) { this.timestampField = timestampField; } + public String getNullIntField() { + return nullIntField; + } + + public void setNullIntField(String nullIntField) { + this.nullIntField = nullIntField; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -414,13 +422,14 @@ public boolean equals(Object o) { Objects.equals(doubleField, that.doubleField) && Objects.equals(stringField, that.stringField) && Objects.equals(booleanField, that.booleanField) && - Objects.equals(timestampField, that.timestampField); + Objects.equals(timestampField, that.timestampField) && + Objects.equals(nullIntField, that.nullIntField); } @Override public int hashCode() { return Objects.hash(shortField, intField, longField, floatField, doubleField, stringField, - booleanField, timestampField); + booleanField, timestampField, nullIntField); } @Override @@ -434,6 +443,7 @@ public String toString() { .add("stringField", stringField) .add("booleanField", booleanField) .add("timestampField", timestampField) + .add("nullIntField", nullIntField) .toString(); } } From 60ca0889f14db3ad3229340c7a8ba1b361ac2c0e Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Sat, 23 Feb 2019 07:17:54 +0900 Subject: [PATCH 04/13] Apply cloud-fan's suggestion: address Upcast in JavaTypeInference --- .../sql/catalyst/JavaTypeInference.scala | 103 +++++++++++++----- .../spark/sql/catalyst/ScalaReflection.scala | 14 ++- 2 files changed, 89 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index b28570e3fff91..d063914d48f89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -26,6 +26,7 @@ import scala.language.existentials import com.google.common.reflect.TypeToken +import org.apache.spark.sql.catalyst.ScalaReflection.upCastToExpectedType import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -194,13 +195,31 @@ object JavaTypeInference { */ def deserializerFor(beanClass: Class[_]): Expression = { val typeToken = TypeToken.of(beanClass) - deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1)) + val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil + val (dataType, nullable) = inferDataType(typeToken) + + // Assumes we are deserializing the first column of a row. 
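+    // The UpCast added here "remembers" the data type the encoder expects, so an incompatible
+    // input schema is reported as an analysis error ("Cannot up cast ...") carrying this walked
+    // type path, rather than surfacing later as a runtime error.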
+ val input = upCastToExpectedType( + GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) + + val expr = deserializerFor(typeToken, input, walkedTypePath) + if (nullable) { + expr + } else { + AssertNotNull(expr, walkedTypePath) + } } - private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = { + private def deserializerFor( + typeToken: TypeToken[_], + path: Expression, + walkedTypePath: Seq[String]): Expression = { + /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = UnresolvedExtractValue(path, - expressions.Literal(part)) + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { + val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => path @@ -211,8 +230,7 @@ object JavaTypeInference { c == classOf[java.lang.Double] || c == classOf[java.lang.Float] || c == classOf[java.lang.Byte] || - c == classOf[java.lang.Boolean] || - c == classOf[java.lang.String] => + c == classOf[java.lang.Boolean] => StaticInvoke( c, ObjectType(c), @@ -236,6 +254,9 @@ object JavaTypeInference { path :: Nil, returnNullable = false) + case c if c == classOf[java.lang.String] => + Invoke(path, "toString", ObjectType(classOf[String])) + case c if c == classOf[java.math.BigDecimal] => Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) @@ -252,32 +273,60 @@ object JavaTypeInference { case _ => None } - primitiveMethod.map { method => - Invoke(path, method, ObjectType(c)) - }.getOrElse { - Invoke( - MapObjects( - p => deserializerFor(typeToken.getComponentType, p), - path, - inferDataType(elementType)._1), - "array", - ObjectType(c)) + val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +: + walkedTypePath + val (dataType, elementNullable) = inferDataType(elementType) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(typeToken.getComponentType, casted, newTypePath) + if (elementNullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } + } + + val arrayCls = MapObjects(mapFunction, path, dataType) + + if (elementNullable) { + Invoke(arrayCls, "array", ObjectType(c), returnNullable = false) + } else { + primitiveMethod.map { method => + Invoke(path, method, ObjectType(c), returnNullable = false) + }.getOrElse { + throw new IllegalStateException("expect primitive array element type") + } } case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - UnresolvedMapObjects( - p => deserializerFor(et, p), - path, - customCollectionCls = Some(c)) + val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +: + walkedTypePath + val (dataType, elementNullable) = inferDataType(et) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. 
+ val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(et, casted, newTypePath) + if (elementNullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } + } + + UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) + val newTypePath = s"""- map key class: "${keyType.getType.getTypeName}", + | value class: "${valueType.getType.getTypeName}"""" +: + walkedTypePath val keyData = Invoke( UnresolvedMapObjects( - p => deserializerFor(keyType, p), + p => deserializerFor(keyType, p, newTypePath), MapKeys(path)), "array", ObjectType(classOf[Array[Any]])) @@ -285,7 +334,7 @@ object JavaTypeInference { val valueData = Invoke( UnresolvedMapObjects( - p => deserializerFor(valueType, p), + p => deserializerFor(valueType, p, newTypePath), MapValues(path)), "array", ObjectType(classOf[Array[Any]])) @@ -310,12 +359,16 @@ object JavaTypeInference { val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType - val (_, nullable) = inferDataType(fieldType) - val constructor = deserializerFor(fieldType, addToPath(fieldName)) + val (dataType, nullable) = inferDataType(fieldType) + val newTypePath = + s"""- field (class: "${fieldType.getType.getTypeName}", + |name: "$fieldName")""".stripMargin +: walkedTypePath + val constructor = deserializerFor(fieldType, addToPath(fieldName, dataType, newTypePath), + newTypePath) val setter = if (nullable) { constructor } else { - AssertNotNull(constructor, Seq("currently no type path record in java")) + AssertNotNull(constructor, newTypePath) } p.getWriteMethod.getName -> setter }.toMap diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d5af91acd071a..52c21aa69eb68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -139,7 +139,7 @@ object ScalaReflection extends ScalaReflection { * This method help us "remember" the required data type by adding a `UpCast`. Note that we * only need to do this for leaf nodes. 
*/ - private def upCastToExpectedType(expr: Expression, expected: DataType, + private[spark] def upCastToExpectedType(expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr case _: ArrayType => expr @@ -349,10 +349,18 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t + val classNameForKey = getClassNameFromType(keyType) + val classNameForValue = getClassNameFromType(valueType) + + val newTypePath = + s"""- map key class: "$classNameForKey", + | value class: "$classNameForValue"""".stripMargin +: + walkedTypePath + UnresolvedCatalystToExternalMap( path, - p => deserializerFor(keyType, p, walkedTypePath), - p => deserializerFor(valueType, p, walkedTypePath), + p => deserializerFor(keyType, p, newTypePath), + p => deserializerFor(valueType, p, newTypePath), mirror.runtimeClass(t.typeSymbol.asClass) ) From 9eeb3910889e3809c119cb1b00ba4898d3b0ab0d Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Sat, 23 Feb 2019 21:37:18 +0900 Subject: [PATCH 05/13] Fix java lint --- .../test/org/apache/spark/sql/JavaBeanDeserializationSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 1bc3e6d2155da..9e50f99a5d1ee 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -325,7 +325,7 @@ public String toString() { } } - public final static class RecordSpark22000 { + public static final class RecordSpark22000 { private String shortField; private String intField; private String longField; From d4c2060d41a2aa48bb0c3f6feeb32c9ca124cf85 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 25 Feb 2019 13:44:24 +0900 Subject: [PATCH 06/13] Refactor: extract duplicated codes in deserializerFor between ScalaReflection and JavaTypeInference --- .../catalyst/DeserializerBuildHelper.scala | 158 +++++++++++++++ .../sql/catalyst/JavaTypeInference.scala | 127 +++++------- .../spark/sql/catalyst/ScalaReflection.scala | 182 ++++++------------ 3 files changed, 263 insertions(+), 204 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala new file mode 100644 index 0000000000000..1545b6998e063 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +private[spark] object DeserializerBuildHelper { + /** Returns the current path with a sub-field extracted. */ + def addToPath( + path: Expression, + part: String, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal( + path: Expression, + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = GetStructField(path, ordinal) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + + def expressionWithNullSafety( + expr: Expression, + nullable: Boolean, + walkedTypePath: Seq[String]): Expression = { + if (nullable) { + expr + } else { + AssertNotNull(expr, walkedTypePath) + } + } + + def deserializerForWithNullSafety( + expr: Expression, + dataType: DataType, + nullable: Boolean, + walkedTypePath: Seq[String], + funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = { + val newExpr = funcForCreatingNewExpr(expr, walkedTypePath) + expressionWithNullSafety(newExpr, nullable, walkedTypePath) + } + + def deserializerForWithNullSafetyAndUpcast( + expr: Expression, + dataType: DataType, + nullable: Boolean, + walkedTypePath: Seq[String], + funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = { + val casted = upCastToExpectedType(expr, dataType, walkedTypePath) + deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath, + funcForCreatingNewExpr) + } + + def createDeserializerForTypesSupportValueOf( + path: Expression, + clazz: Class[_]): Expression = { + StaticInvoke( + clazz, + ObjectType(clazz), + "valueOf", + path :: Nil, + returnNullable = false) + } + + def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = { + Invoke(path, "toString", ObjectType(classOf[java.lang.String]), + returnNullable = returnNullable) + } + + def createDeserializerForSqlDate(path: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + path :: Nil, + returnNullable = false) + } + + def createDeserializerForSqlTimestamp(path: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + path :: Nil, + returnNullable = false) + } + + def createDeserializerForJavaBigDecimal( + path: Expression, + returnNullable: Boolean): Expression = { + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = returnNullable) + } + + def 
createDeserializerForScalaBigDecimal( + path: Expression, + returnNullable: Boolean): Expression = { + Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = returnNullable) + } + + def createDeserializerForJavaBigInteger( + path: Expression, + returnNullable: Boolean): Expression = { + Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + returnNullable = returnNullable) + } + + def createDeserializerForScalaBigInt(path: Expression): Expression = { + Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + returnNullable = false) + } + + /** + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. + */ + private def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _: ArrayType => expr + case _: MapType => expr + case _ => UpCast(expr, expected, walkedTypePath) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index d063914d48f89..1e35f52bd25e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -26,8 +26,8 @@ import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.catalyst.ScalaReflection.upCastToExpectedType -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -199,28 +199,16 @@ object JavaTypeInference { val (dataType, nullable) = inferDataType(typeToken) // Assumes we are deserializing the first column of a row. - val input = upCastToExpectedType( - GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) - - val expr = deserializerFor(typeToken, input, walkedTypePath) - if (nullable) { - expr - } else { - AssertNotNull(expr, walkedTypePath) - } + deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, + nullable = nullable, walkedTypePath, (casted, walkedTypePath) => { + deserializerFor(typeToken, casted, walkedTypePath) + }) } private def deserializerFor( typeToken: TypeToken[_], path: Expression, walkedTypePath: Seq[String]): Expression = { - - /** Returns the current path with a sub-field extracted. 
*/ - def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) - upCastToExpectedType(newPath, dataType, walkedTypePath) - } - typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => path @@ -231,60 +219,36 @@ object JavaTypeInference { c == classOf[java.lang.Float] || c == classOf[java.lang.Byte] || c == classOf[java.lang.Boolean] => - StaticInvoke( - c, - ObjectType(c), - "valueOf", - path :: Nil, - returnNullable = false) + createDeserializerForTypesSupportValueOf(path, c) case c if c == classOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - ObjectType(c), - "toJavaDate", - path :: Nil, - returnNullable = false) + createDeserializerForSqlDate(path) case c if c == classOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - ObjectType(c), - "toJavaTimestamp", - path :: Nil, - returnNullable = false) + createDeserializerForSqlTimestamp(path) case c if c == classOf[java.lang.String] => - Invoke(path, "toString", ObjectType(classOf[String])) + createDeserializerForString(path, returnNullable = true) case c if c == classOf[java.math.BigDecimal] => - Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + createDeserializerForJavaBigDecimal(path, returnNullable = true) + + case c if c == classOf[java.math.BigInteger] => + createDeserializerForJavaBigInteger(path, returnNullable = true) case c if c.isArray => val elementType = c.getComponentType - val primitiveMethod = elementType match { - case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") - case c if c == java.lang.Byte.TYPE => Some("toByteArray") - case c if c == java.lang.Short.TYPE => Some("toShortArray") - case c if c == java.lang.Integer.TYPE => Some("toIntArray") - case c if c == java.lang.Long.TYPE => Some("toLongArray") - case c if c == java.lang.Float.TYPE => Some("toFloatArray") - case c if c == java.lang.Double.TYPE => Some("toDoubleArray") - case _ => None - } - val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +: walkedTypePath val (dataType, elementNullable) = inferDataType(elementType) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
- val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(typeToken.getComponentType, casted, newTypePath) - if (elementNullable) { - converter - } else { - AssertNotNull(converter, newTypePath) - } + deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) } val arrayCls = MapObjects(mapFunction, path, dataType) @@ -292,11 +256,18 @@ object JavaTypeInference { if (elementNullable) { Invoke(arrayCls, "array", ObjectType(c), returnNullable = false) } else { - primitiveMethod.map { method => - Invoke(path, method, ObjectType(c), returnNullable = false) - }.getOrElse { - throw new IllegalStateException("expect primitive array element type") + val primitiveMethod = elementType match { + case c if c == java.lang.Integer.TYPE => "toIntArray" + case c if c == java.lang.Long.TYPE => "toLongArray" + case c if c == java.lang.Double.TYPE => "toDoubleArray" + case c if c == java.lang.Float.TYPE => "toFloatArray" + case c if c == java.lang.Short.TYPE => "toShortArray" + case c if c == java.lang.Byte.TYPE => "toByteArray" + case c if c == java.lang.Boolean.TYPE => "toBooleanArray" + case other => throw new IllegalStateException("expect primitive array element type " + + "but got " + other) } + Invoke(path, primitiveMethod, ObjectType(c), returnNullable = false) } case c if listType.isAssignableFrom(typeToken) => @@ -306,13 +277,12 @@ object JavaTypeInference { val (dataType, elementNullable) = inferDataType(et) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. - val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(et, casted, newTypePath) - if (elementNullable) { - converter - } else { - AssertNotNull(converter, newTypePath) - } + deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(et, casted, typePath)) } UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c)) @@ -347,12 +317,9 @@ object JavaTypeInference { returnNullable = false) case other if other.isEnum => - StaticInvoke( - other, - ObjectType(other), - "valueOf", - Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, - returnNullable = false) + createDeserializerForTypesSupportValueOf( + createDeserializerForString(path, returnNullable = false), + other) case other => val properties = getJavaBeanReadableAndWritableProperties(other) @@ -363,13 +330,13 @@ object JavaTypeInference { val newTypePath = s"""- field (class: "${fieldType.getType.getTypeName}", |name: "$fieldName")""".stripMargin +: walkedTypePath - val constructor = deserializerFor(fieldType, addToPath(fieldName, dataType, newTypePath), - newTypePath) - val setter = if (nullable) { - constructor - } else { - AssertNotNull(constructor, newTypePath) - } + val setter = deserializerForWithNullSafety( + path, + dataType, + nullable = nullable, + newTypePath, + (expr, typePath) => deserializerFor(fieldType, + addToPath(expr, fieldName, dataType, typePath), typePath)) p.getWriteMethod.getName -> setter }.toMap diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 52c21aa69eb68..f1f560585c69c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.catalyst -import java.lang.reflect.Constructor - -import scala.util.Properties - import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal +import org.apache.spark.sql.catalyst.expressions.{Expression, _} import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -128,25 +125,6 @@ object ScalaReflection extends ScalaReflection { case _ => false } - /** - * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff - * and lost the required data type, which may lead to runtime error if the real type doesn't - * match the encoder's schema. - * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type - * is [a: int, b: long], then we will hit runtime error and say that we can't construct class - * `Data` with int and long, because we lost the information that `b` should be a string. - * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * only need to do this for leaf nodes. - */ - private[spark] def upCastToExpectedType(expr: Expression, expected: DataType, - walkedTypePath: Seq[String]): Expression = expected match { - case _: StructType => expr - case _: ArrayType => expr - case _: MapType => expr - case _ => UpCast(expr, expected, walkedTypePath) - } - /** * Returns an expression that can be used to deserialize a Spark SQL representation to an object * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of @@ -162,15 +140,9 @@ object ScalaReflection extends ScalaReflection { val Schema(dataType, nullable) = schemaFor(tpe) // Assumes we are deserializing the first column of a row. - val input = upCastToExpectedType( - GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) - - val expr = deserializerFor(tpe, input, walkedTypePath) - if (nullable) { - expr - } else { - AssertNotNull(expr, walkedTypePath) - } + deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, + nullable = nullable, walkedTypePath, + (casted, typePath) => deserializerFor(tpe, casted, typePath)) } /** @@ -185,22 +157,6 @@ object ScalaReflection extends ScalaReflection { tpe: `Type`, path: Expression, walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { - - /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) - upCastToExpectedType(newPath, dataType, walkedTypePath) - } - - /** Returns the current path with a field at ordinal extracted. 
*/ - def addToPathOrdinal( - ordinal: Int, - dataType: DataType, - walkedTypePath: Seq[String]): Expression = { - val newPath = GetStructField(path, ordinal) - upCastToExpectedType(newPath, dataType, walkedTypePath) - } - tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path @@ -211,73 +167,53 @@ object ScalaReflection extends ScalaReflection { WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => - val boxedType = classOf[java.lang.Integer] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Integer]) case t if t <:< localTypeOf[java.lang.Long] => - val boxedType = classOf[java.lang.Long] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Long]) case t if t <:< localTypeOf[java.lang.Double] => - val boxedType = classOf[java.lang.Double] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Double]) case t if t <:< localTypeOf[java.lang.Float] => - val boxedType = classOf[java.lang.Float] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Float]) case t if t <:< localTypeOf[java.lang.Short] => - val boxedType = classOf[java.lang.Short] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Short]) case t if t <:< localTypeOf[java.lang.Byte] => - val boxedType = classOf[java.lang.Byte] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Byte]) case t if t <:< localTypeOf[java.lang.Boolean] => - val boxedType = classOf[java.lang.Boolean] - val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) + createDeserializerForTypesSupportValueOf(path, + classOf[java.lang.Boolean]) case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - path :: Nil, - returnNullable = false) + createDeserializerForSqlDate(path) case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - path :: Nil, - returnNullable = false) + createDeserializerForSqlTimestamp(path) case t if t <:< localTypeOf[java.lang.String] => - Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) + createDeserializerForString(path, returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), - returnNullable = false) + createDeserializerForJavaBigDecimal(path, returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) + 
createDeserializerForScalaBigDecimal(path, returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), - returnNullable = false) + createDeserializerForJavaBigInteger(path, returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), - returnNullable = false) + createDeserializerForScalaBigInt(path) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -287,13 +223,12 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. - val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, casted, newTypePath) - if (elementNullable) { - converter - } else { - AssertNotNull(converter, newTypePath) - } + deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(elementType, casted, typePath)) } val arrayData = UnresolvedMapObjects(mapFunction, path) @@ -326,14 +261,12 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. - val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, casted, newTypePath) - if (elementNullable) { - converter - } else { - AssertNotNull(converter, newTypePath) - } + deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(elementType, casted, typePath)) } val companion = t.dealias.typeSymbol.companion.typeSignature @@ -346,7 +279,6 @@ object ScalaReflection extends ScalaReflection { UnresolvedMapObjects(mapFunction, path, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => - // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t val classNameForKey = getClassNameFromType(keyType) @@ -391,24 +323,26 @@ object ScalaReflection extends ScalaReflection { val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - // For tuples, we based grab the inner fields by ordinal instead of name. - val constructor = if (cls.getName startsWith "scala.Tuple") { - deserializerFor( - fieldType, - addToPathOrdinal(i, dataType, newTypePath), - newTypePath) - } else { - deserializerFor( - fieldType, - addToPath(fieldName, dataType, newTypePath), - newTypePath) - } - if (!nullable) { - AssertNotNull(constructor, newTypePath) - } else { - constructor - } + // For tuples, we based grab the inner fields by ordinal instead of name. 
+ deserializerForWithNullSafety( + path, + dataType, + nullable = nullable, + newTypePath, + (expr, typePath) => { + if (cls.getName startsWith "scala.Tuple") { + deserializerFor( + fieldType, + addToPathOrdinal(expr, i, dataType, typePath), + newTypePath) + } else { + deserializerFor( + fieldType, + addToPath(expr, fieldName, dataType, typePath), + newTypePath) + } + }) } val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) From 9c040e2dff215168633bd2e27b3695dbe0a00383 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 25 Feb 2019 16:11:50 +0900 Subject: [PATCH 07/13] Add new test: upcast doesn't help making Dataset be compatible with Java Bean Encoder --- .../sql/JavaBeanDeserializationSuite.java | 55 +++++++++++++++++-- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 9e50f99a5d1ee..49ff522cee8e8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -20,15 +20,12 @@ import java.io.Serializable; import java.util.*; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.*; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructType; import org.junit.*; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; import org.apache.spark.sql.test.TestSparkSession; public class JavaBeanDeserializationSuite implements Serializable { @@ -157,6 +154,36 @@ public void testSpark22000() { Assert.assertEquals(records, records); } + @Test + public void testSpark22000FailToUpcast() { + List inputRows = new ArrayList<>(); + for (long idx = 0 ; idx < 5 ; idx++) { + Row row = createRecordSpark22000FailToUpcastRow(idx); + inputRows.add(row); + } + + // Here we try to convert the fields, from string type to int, which upcast doesn't help. + Encoder encoder = + Encoders.bean(RecordSpark22000FailToUpcast.class); + + StructType schema = new StructType().add("id", DataTypes.StringType); + + Dataset dataFrame = spark.createDataFrame(inputRows, schema); + + try { + dataFrame.as(encoder).collect(); + Assert.fail("Expected AnalysisException, but passed."); + } catch (Throwable e) { + // Here we need to handle weird case: compiler complains AnalysisException never be thrown + // in try statement, but it can be thrown actually. Maybe Scala-Java interop issue? 
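+      // (A plausible explanation: AnalysisException is a checked exception, and Scala-compiled
+      // methods carry no `throws` clauses, so javac cannot prove the try block may throw it and
+      // rejects catching it directly.)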
+ if (e instanceof AnalysisException) { + Assert.assertTrue(e.getMessage().contains("Cannot up cast ")); + } else { + throw e; + } + } + } + private static Row createRecordSpark22000Row(Long index) { Object[] values = new Object[] { index.shortValue(), @@ -187,6 +214,11 @@ private static RecordSpark22000 createRecordSpark22000(Row recordRow) { return record; } + private static Row createRecordSpark22000FailToUpcastRow(Long index) { + Object[] values = new Object[] { String.valueOf(index) }; + return new GenericRow(values); + } + public static class ArrayRecord { private int id; @@ -447,4 +479,19 @@ public String toString() { .toString(); } } + + public static final class RecordSpark22000FailToUpcast { + private Integer id; + + public RecordSpark22000FailToUpcast() { + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + } } From 2007452fc19bc5e273dcb71d92386bec66b0bcb4 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 25 Feb 2019 21:15:09 +0900 Subject: [PATCH 08/13] Address review comments from cloud-fan --- .../sql/catalyst/JavaTypeInference.scala | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1e35f52bd25e5..fd020df4399cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -241,20 +241,19 @@ object JavaTypeInference { val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +: walkedTypePath val (dataType, elementNullable) = inferDataType(elementType) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. - deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) - } - - val arrayCls = MapObjects(mapFunction, path, dataType) if (elementNullable) { - Invoke(arrayCls, "array", ObjectType(c), returnNullable = false) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. 
+ deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) + } + val arrayData = MapObjects(mapFunction, path, dataType) + Invoke(arrayData, "array", ObjectType(c), returnNullable = false) } else { val primitiveMethod = elementType match { case c if c == java.lang.Integer.TYPE => "toIntArray" @@ -289,9 +288,8 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = s"""- map key class: "${keyType.getType.getTypeName}", - | value class: "${valueType.getType.getTypeName}"""" +: - walkedTypePath + val newTypePath = (s"- map key class: ${keyType.getType.getTypeName}" + + s", value class: ${valueType.getType.getTypeName}") +: walkedTypePath val keyData = Invoke( @@ -327,9 +325,8 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = - s"""- field (class: "${fieldType.getType.getTypeName}", - |name: "$fieldName")""".stripMargin +: walkedTypePath + val newTypePath = (s"- field (class: ${fieldType.getType.getTypeName}" + + s", name: $fieldName)") +: walkedTypePath val setter = deserializerForWithNullSafety( path, dataType, From 4d564abbe39aaa7c61d2dd12b45cd46c557ce889 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Mon, 25 Feb 2019 21:18:55 +0900 Subject: [PATCH 09/13] Address weird style from ScalaReflection as well --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f1f560585c69c..34f8c5764598e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -284,9 +284,7 @@ object ScalaReflection extends ScalaReflection { val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) - val newTypePath = - s"""- map key class: "$classNameForKey", - | value class: "$classNameForValue"""".stripMargin +: + val newTypePath = s"- map key class: $classNameForKey, value class: $classNameForValue" +: walkedTypePath UnresolvedCatalystToExternalMap( @@ -322,7 +320,7 @@ object ScalaReflection extends ScalaReflection { val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) - val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + val newTypePath = s"- field (class: $clsName, name: $fieldName)" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. 
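      // Clarifying note, not a line of this patch: walkedTypePath exists for error reporting --
      // it is threaded into AssertNotNull and into the up-cast added by the deserializer
      // helpers -- so reshaping these strings changes user-facing analysis messages; the
      // follow-up commit below restores the quoted class/field names on top of the
      // single-line form introduced here.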
deserializerForWithNullSafety( From e01bfe645f5d27e1cae1f703fdd3a76bb168291e Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 26 Feb 2019 07:45:13 +0900 Subject: [PATCH 10/13] Fix build failure due to slightly change of walked type path representation --- .../org/apache/spark/sql/catalyst/JavaTypeInference.scala | 8 ++++---- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index fd020df4399cc..2ede5a3267bb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -288,8 +288,8 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = (s"- map key class: ${keyType.getType.getTypeName}" + - s", value class: ${valueType.getType.getTypeName}") +: walkedTypePath + val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" + + s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath val keyData = Invoke( @@ -325,8 +325,8 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = (s"- field (class: ${fieldType.getType.getTypeName}" + - s", name: $fieldName)") +: walkedTypePath + val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" + + s""", name: "$fieldName")""") +: walkedTypePath val setter = deserializerForWithNullSafety( path, dataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 34f8c5764598e..7403686ec3aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -284,8 +284,8 @@ object ScalaReflection extends ScalaReflection { val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) - val newTypePath = s"- map key class: $classNameForKey, value class: $classNameForValue" +: - walkedTypePath + val newTypePath = (s"""- map key class: "${classNameForKey}"""" + + s""", value class: "${classNameForValue}"""") +: walkedTypePath UnresolvedCatalystToExternalMap( path, @@ -320,7 +320,8 @@ object ScalaReflection extends ScalaReflection { val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) - val newTypePath = s"- field (class: $clsName, name: $fieldName)" +: walkedTypePath + val newTypePath = (s"""- field (class: "$clsName", """ + + s"""name: "$fieldName")""") +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. 
deserializerForWithNullSafety( From 8cbad26298bd57c852e3847ad21b1fd64df5cfb2 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 26 Feb 2019 14:31:43 +0900 Subject: [PATCH 11/13] Address review comments from cloud-fan & maropu --- .../catalyst/DeserializerBuildHelper.scala | 24 +++++++-------- .../sql/catalyst/JavaTypeInference.scala | 29 ++++++++++++------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 1545b6998e063..e71955ab4e757 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -private[spark] object DeserializerBuildHelper { +object DeserializerBuildHelper { /** Returns the current path with a sub-field extracted. */ def addToPath( path: Expression, @@ -44,17 +44,6 @@ private[spark] object DeserializerBuildHelper { upCastToExpectedType(newPath, dataType, walkedTypePath) } - def expressionWithNullSafety( - expr: Expression, - nullable: Boolean, - walkedTypePath: Seq[String]): Expression = { - if (nullable) { - expr - } else { - AssertNotNull(expr, walkedTypePath) - } - } - def deserializerForWithNullSafety( expr: Expression, dataType: DataType, @@ -76,6 +65,17 @@ private[spark] object DeserializerBuildHelper { funcForCreatingNewExpr) } + private def expressionWithNullSafety( + expr: Expression, + nullable: Boolean, + walkedTypePath: Seq[String]): Expression = { + if (nullable) { + expr + } else { + AssertNotNull(expr, walkedTypePath) + } + } + def createDeserializerForTypesSupportValueOf( path: Expression, clazz: Class[_]): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 2ede5a3267bb6..0f7dcba355062 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -256,17 +256,26 @@ object JavaTypeInference { Invoke(arrayData, "array", ObjectType(c), returnNullable = false) } else { val primitiveMethod = elementType match { - case c if c == java.lang.Integer.TYPE => "toIntArray" - case c if c == java.lang.Long.TYPE => "toLongArray" - case c if c == java.lang.Double.TYPE => "toDoubleArray" - case c if c == java.lang.Float.TYPE => "toFloatArray" - case c if c == java.lang.Short.TYPE => "toShortArray" - case c if c == java.lang.Byte.TYPE => "toByteArray" - case c if c == java.lang.Boolean.TYPE => "toBooleanArray" - case other => throw new IllegalStateException("expect primitive array element type " + - "but got " + other) + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + case _ => None + } + 
primitiveMethod.map { method => + Invoke(path, method, ObjectType(c)) + }.getOrElse { + Invoke( + MapObjects( + p => deserializerFor(typeToken.getComponentType, p, newTypePath), + path, + inferDataType(elementType)._1), + "array", + ObjectType(c)) } - Invoke(path, primitiveMethod, ObjectType(c), returnNullable = false) } case c if listType.isAssignableFrom(typeToken) => From 7ff01dbd27ffa90f118f317a8d44eb3a9d085995 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 26 Feb 2019 22:28:59 +0900 Subject: [PATCH 12/13] Revert back wrong change and sync again with ScalaReflection implementation --- .../sql/catalyst/JavaTypeInference.scala | 51 ++++++++----------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 0f7dcba355062..b791483737920 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -241,41 +241,32 @@ object JavaTypeInference { val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +: walkedTypePath val (dataType, elementNullable) = inferDataType(elementType) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) + } + val arrayData = UnresolvedMapObjects(mapFunction, path) if (elementNullable) { - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. 
- deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) - } - val arrayData = MapObjects(mapFunction, path, dataType) - Invoke(arrayData, "array", ObjectType(c), returnNullable = false) + Invoke(arrayData, "array", ObjectType(c)) } else { val primitiveMethod = elementType match { - case c if c == java.lang.Integer.TYPE => Some("toIntArray") - case c if c == java.lang.Long.TYPE => Some("toLongArray") - case c if c == java.lang.Double.TYPE => Some("toDoubleArray") - case c if c == java.lang.Float.TYPE => Some("toFloatArray") - case c if c == java.lang.Short.TYPE => Some("toShortArray") - case c if c == java.lang.Byte.TYPE => Some("toByteArray") - case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") - case _ => None - } - primitiveMethod.map { method => - Invoke(path, method, ObjectType(c)) - }.getOrElse { - Invoke( - MapObjects( - p => deserializerFor(typeToken.getComponentType, p, newTypePath), - path, - inferDataType(elementType)._1), - "array", - ObjectType(c)) + case c if c == java.lang.Integer.TYPE => "toIntArray" + case c if c == java.lang.Long.TYPE => "toLongArray" + case c if c == java.lang.Double.TYPE => "toDoubleArray" + case c if c == java.lang.Float.TYPE => "toFloatArray" + case c if c == java.lang.Short.TYPE => "toShortArray" + case c if c == java.lang.Byte.TYPE => "toByteArray" + case c if c == java.lang.Boolean.TYPE => "toBooleanArray" + case other => throw new IllegalStateException("expect primitive array element type " + + "but got " + other) } + Invoke(arrayData, primitiveMethod, ObjectType(c)) } case c if listType.isAssignableFrom(typeToken) => From 24a1b195bad49e05353b80c85d30e32e8c898a52 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 26 Feb 2019 23:16:33 +0900 Subject: [PATCH 13/13] Apply preferred code style to both JavaTypeInference and ScalaReflection --- .../sql/catalyst/JavaTypeInference.scala | 27 +++++++++---------- .../spark/sql/catalyst/ScalaReflection.scala | 26 ++++++++---------- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index b791483737920..dafa87839ec6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -252,22 +252,19 @@ object JavaTypeInference { } val arrayData = UnresolvedMapObjects(mapFunction, path) - if (elementNullable) { - Invoke(arrayData, "array", ObjectType(c)) - } else { - val primitiveMethod = elementType match { - case c if c == java.lang.Integer.TYPE => "toIntArray" - case c if c == java.lang.Long.TYPE => "toLongArray" - case c if c == java.lang.Double.TYPE => "toDoubleArray" - case c if c == java.lang.Float.TYPE => "toFloatArray" - case c if c == java.lang.Short.TYPE => "toShortArray" - case c if c == java.lang.Byte.TYPE => "toByteArray" - case c if c == java.lang.Boolean.TYPE => "toBooleanArray" - case other => throw new IllegalStateException("expect primitive array element type " + - "but got " + other) - } - Invoke(arrayData, primitiveMethod, ObjectType(c)) + + val methodName = elementType match { + case c if c == java.lang.Integer.TYPE => "toIntArray" + case c if c == java.lang.Long.TYPE => "toLongArray" + case c if c == java.lang.Double.TYPE => 
"toDoubleArray" + case c if c == java.lang.Float.TYPE => "toFloatArray" + case c if c == java.lang.Short.TYPE => "toShortArray" + case c if c == java.lang.Byte.TYPE => "toByteArray" + case c if c == java.lang.Boolean.TYPE => "toBooleanArray" + // non-primitive + case _ => "array" } + Invoke(arrayData, methodName, ObjectType(c)) case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7403686ec3aa6..741cba80640b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -234,22 +234,18 @@ object ScalaReflection extends ScalaReflection { val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) - if (elementNullable) { - Invoke(arrayData, "array", arrayCls, returnNullable = false) - } else { - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => "toIntArray" - case t if t <:< definitions.LongTpe => "toLongArray" - case t if t <:< definitions.DoubleTpe => "toDoubleArray" - case t if t <:< definitions.FloatTpe => "toFloatArray" - case t if t <:< definitions.ShortTpe => "toShortArray" - case t if t <:< definitions.ByteTpe => "toByteArray" - case t if t <:< definitions.BooleanTpe => "toBooleanArray" - case other => throw new IllegalStateException("expect primitive array element type " + - "but got " + other) - } - Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) + val methodName = elementType match { + case t if t <:< definitions.IntTpe => "toIntArray" + case t if t <:< definitions.LongTpe => "toLongArray" + case t if t <:< definitions.DoubleTpe => "toDoubleArray" + case t if t <:< definitions.FloatTpe => "toFloatArray" + case t if t <:< definitions.ShortTpe => "toShortArray" + case t if t <:< definitions.ByteTpe => "toByteArray" + case t if t <:< definitions.BooleanTpe => "toBooleanArray" + // non-primitive + case _ => "array" } + Invoke(arrayData, methodName, arrayCls, returnNullable = false) // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.