From 2a91a87d27e31497ebefd2ad80531c0017fafbc5 Mon Sep 17 00:00:00 2001
From: jeanlyn
Date: Thu, 5 Feb 2015 11:25:36 +0800
Subject: [PATCH] add more test cases and clean up the code

---
 .../apache/spark/sql/hive/TableReader.scala   | 18 ++++-------
 .../sql/hive/InsertIntoHiveTableSuite.scala   | 30 +++++++++++++++++--
 .../org/apache/spark/sql/hive/Shim12.scala    |  7 +++--
 .../org/apache/spark/sql/hive/Shim13.scala    |  1 +
 4 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index c54ce886a7f35..a6cfd978b5f62 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -335,23 +335,17 @@ private[hive] object HadoopTableReader extends HiveInspectors {
       }
     }
 
+    /**
+     * When soi and deserializer.getObjectInspector are equal, getConverter
+     * returns an `IdentityConverter`, which means the value is not converted
+     * when the schemas match.
+     */
     val partTblObjectInspectorConverter = ObjectInspectorConverters.getConverter(
       deserializer.getObjectInspector, soi)
 
     // Map each tuple to a row object
     iterator.map { value =>
-      val raw = convertdeserializer match {
-        case Some(convert) =>
-          if (deserializer.getObjectInspector.equals(convert.getObjectInspector)) {
-            deserializer.deserialize(value)
-          }
-          // If partition schema does not match table schema, update the row to match
-          else {
-            partTblObjectInspectorConverter.convert(deserializer.deserialize(value))
-          }
-        case None =>
-          deserializer.deserialize(value)
-      }
+      val raw = partTblObjectInspectorConverter.convert(deserializer.deserialize(value))
       var i = 0
       while (i < fieldRefs.length) {
         val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index de87da1b7f82e..33e859427a4b0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.hive.test.TestHive._
 
 case class TestData(key: Int, value: String)
 
+case class ThreeColumnTable(key: Int, value: String, key1: String)
+
 class InsertIntoHiveTableSuite extends QueryTest {
   val testData = TestHive.sparkContext.parallelize(
     (1 to 100).map(i => TestData(i, i.toString)))
@@ -172,19 +174,43 @@ class InsertIntoHiveTableSuite extends QueryTest {
     sql("DROP TABLE hiveTableWithStructValue")
   }
 
-  
+
   test("SPARK-5498:partition schema does not match table schema"){
     val testData = TestHive.sparkContext.parallelize(
       (1 to 10).map(i => TestData(i, i.toString)))
     testData.registerTempTable("testData")
+
+    val testDatawithNull = TestHive.sparkContext.parallelize(
+      (1 to 10).map(i => ThreeColumnTable(i, i.toString, null)))
+
     val tmpDir = Files.createTempDir()
     sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
     sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData")
+
+    // test schema is the same
     sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
     checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
       testData.toSchemaRDD.collect.toSeq
     )
 
-    sql("DROP TABLE table_with_partition")
+    // test a different type of field
+    sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
+    checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
+      testData.toSchemaRDD.collect.toSeq
+    )
+
+    // add a column to the table
+    sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)")
+    checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "),
+      testDatawithNull.toSchemaRDD.collect.toSeq
+    )
+
+    // change a column name of the table
+    sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT")
+    checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "),
+      testData.toSchemaRDD.collect.toSeq
+    )
+
+    sql("DROP TABLE table_with_partition")
   }
 }
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 762eae57ff86d..4fe8e8621f7b8 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -243,8 +243,11 @@ private[hive] object HiveShim {
   }
 
   // make getConvertedOI compatible between 0.12.0 and 0.13.1
-  def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = {
-    ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, new java.lang.Boolean(true))
+  def getConvertedOI(
+      inputOI: ObjectInspector,
+      outputOI: ObjectInspector,
+      equalsCheck: java.lang.Boolean = new java.lang.Boolean(true)): ObjectInspector = {
+    ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, equalsCheck)
   }
 
   def prepareWritable(w: Writable): Writable = {
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index ff61479b33dda..55041417ecc17 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.hive
 
+import java.util
 import java.util.{ArrayList => JArrayList}
 import java.util.Properties
 import java.rmi.server.UID