diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
index 96faebc5a8687..aadd7f531cdc2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -94,7 +94,7 @@ case class HiveTableScan(
         (_: Any, partitionKeys: Array[String]) => {
           val value = partitionKeys(ordinal)
           val dataType = relation.partitionKeys(ordinal).dataType
-          castFromString(value, dataType)
+          unwrapHiveData(castFromString(value, dataType))
         }
       } else {
         val ref = objectInspector.getAllStructFieldRefs
@@ -102,12 +102,19 @@ case class HiveTableScan(
           .getOrElse(sys.error(s"Can't find attribute $a"))
         (row: Any, _: Array[String]) => {
           val data = objectInspector.getStructFieldData(row, ref)
-          unwrapData(data, ref.getFieldObjectInspector)
+          unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
         }
       }
     }
   }
 
+  private def unwrapHiveData(value: Any) = value match {
+    case maybeNull: String if maybeNull.toLowerCase == "null" => null
+    case varchar: HiveVarchar => varchar.getValue
+    case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
+    case other => other
+  }
+
   private def castFromString(value: String, dataType: DataType) = {
     Cast(Literal(value), dataType).eval(null)
   }
@@ -143,20 +150,34 @@ case class HiveTableScan(
   }
 
   def execute() = {
-    inputRdd.map { row =>
-      val values = row match {
-        case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
-          attributeFunctions.map(_(deserializedRow, partitionKeys))
-        case deserializedRow: AnyRef =>
-          attributeFunctions.map(_(deserializedRow, Array.empty))
+    inputRdd.mapPartitions { iterator =>
+      if (iterator.isEmpty) {
+        Iterator.empty
+      } else {
+        val mutableRow = new GenericMutableRow(attributes.length)
+        val buffered = iterator.buffered
+
+        (buffered.head match {
+          case Array(_, _) =>
+            buffered.map { case Array(deserializedRow, partitionKeys: Array[String]) =>
+              (deserializedRow, partitionKeys)
+            }
+
+          case _ =>
+            buffered.map { deserializedRow =>
+              (deserializedRow, Array.empty[String])
+            }
+        }).map { case (deserializedRow, partitionKeys: Array[String]) =>
+          var i = 0
+
+          while (i < attributes.length) {
+            mutableRow(i) = attributeFunctions(i)(deserializedRow, partitionKeys)
+            i += 1
+          }
+
+          mutableRow: Row
+        }
       }
-      buildRow(values.map {
-        case n: String if n.toLowerCase == "null" => null
-        case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
-        case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
-          BigDecimal(decimal.bigDecimalValue)
-        case other => other
-      })
     }
   }
 
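
For context, here is a minimal standalone sketch of the row-reuse pattern this diff introduces: allocate one mutable row per partition and refill it in place for every record, instead of building a fresh row object per record. This is plain Scala, not Spark code; `MutableBuffer` and `fillRows` are hypothetical stand-ins for `GenericMutableRow` and the `mapPartitions` body above.

```scala
// Hypothetical, self-contained illustration of the per-partition row-reuse pattern.
object MutableRowReuseSketch {

  // Stand-in for GenericMutableRow: a fixed-size, in-place updatable record.
  final class MutableBuffer(size: Int) {
    private val values = new Array[Any](size)
    def update(i: Int, v: Any): Unit = values(i) = v
    def apply(i: Int): Any = values(i)
    override def toString: String = values.mkString("[", ", ", "]")
  }

  // Stand-in for the mapPartitions body: one buffer is allocated for the whole
  // iterator and overwritten for each element, so consumers must read (or copy)
  // each result before advancing to the next one.
  def fillRows(rows: Iterator[Array[Any]], fieldCount: Int): Iterator[MutableBuffer] = {
    if (rows.isEmpty) {
      Iterator.empty
    } else {
      val reused = new MutableBuffer(fieldCount)
      rows.map { raw =>
        var i = 0
        while (i < fieldCount) {
          reused(i) = raw(i)
          i += 1
        }
        reused
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val input = Iterator(Array[Any](1, "a"), Array[Any](2, "b"))
    fillRows(input, 2).foreach(println) // prints [1, a] then [2, b]
  }
}
```

The trade-off mirrors the one in the diff: avoiding a per-record allocation reduces GC pressure on large scans, at the cost of handing downstream code a row that is only valid until the next element is produced.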