renamed to ArrowConverters
defined ArrowPayload and encapsulated Arrow classes in ArrowConverters

addressed some minor comments in code review

closes apache#21
BryanCutler committed Jan 30, 2017
1 parent fadf588 commit 102cf3f
Showing 3 changed files with 77 additions and 44 deletions.
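The commit message above outlines the reworked conversion path; a minimal sketch of how the new pieces fit together, assuming rows: Array[InternalRow] and schema: StructType are already in hand (all names here are placeholders, and since the helpers are private[sql] the sketch has to live under org.apache.spark.sql):

    package org.apache.spark.sql

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.StructType

    object ArrowConversionSketch {
      // Convert already-collected rows into an ArrowPayload, then serialize it for transfer.
      def toArrowBytes(rows: Array[InternalRow], schema: StructType): Array[Byte] = {
        val converter = new ArrowConverters                          // wraps an Arrow RootAllocator
        val payload = converter.internalRowsToPayload(rows, schema)  // currently a single record batch
        ArrowConverters.payloadToByteArray(payload, schema)          // Arrow schema + batches as bytes
      }
    }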
@@ -17,21 +17,50 @@

package org.apache.spark.sql

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.language.implicitConversions

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
import org.apache.arrow.vector._
import org.apache.arrow.vector.BaseValueVector.BaseMutator
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

- object Arrow {
/**
* Intermediate data structure returned from Arrow conversions
*/
private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch]

/**
* Class that wraps an Arrow RootAllocator used in conversion
*/
private[sql] class ArrowConverters {
private val _allocator = new RootAllocator(Long.MaxValue)

private[sql] def allocator: RootAllocator = _allocator

private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload {
private val iter = batches.iterator

override def next(): ArrowRecordBatch = iter.next()
override def hasNext: Boolean = iter.hasNext
}

def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = {
val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator)
new ArrowStaticPayload(batch)
}
}
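An ArrowPayload is just an Iterator[ArrowRecordBatch], so a consumer that does not go through payloadToByteArray below has to close each batch itself to release the allocator's buffers (the updated test suite does exactly this); a small sketch, with payload assumed to come from internalRowsToPayload above:

    payload.foreach { batch =>
      try {
        // hand the batch to a consumer, e.g. an Arrow VectorLoader
        println(s"rows in batch: ${batch.getLength}")
      } finally {
        batch.close()  // frees the buffers owned by the converter's RootAllocator
      }
    }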

+ private[sql] object ArrowConverters {

/**
* Map a Spark Dataset type to ArrowType.
@@ -49,7 +78,7 @@ object Arrow {
case BinaryType => ArrowType.Binary.INSTANCE
case DateType => ArrowType.Date.INSTANCE
case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
- case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+ case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
}
}
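Only a handful of branches of this mapping are visible in the hunk above; the remaining cases are collapsed, but judging from the imports they presumably follow the same pattern, along these lines (a hedged sketch, not the committed code):

    import org.apache.arrow.vector.types.FloatingPointPrecision
    import org.apache.arrow.vector.types.pojo.ArrowType
    import org.apache.spark.sql.types._

    def sparkTypeToArrowTypeSketch(dataType: DataType): ArrowType = dataType match {
      case BooleanType => ArrowType.Bool.INSTANCE
      case IntegerType => new ArrowType.Int(32, true)
      case LongType => new ArrowType.Int(64, true)
      case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
      case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
      case StringType => ArrowType.Utf8.INSTANCE
      case other => throw new UnsupportedOperationException(s"Unsupported data type: $other")
    }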

@@ -109,6 +138,25 @@
}
new Schema(arrowFields.toList.asJava)
}
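The body of schemaToArrowSchema is mostly collapsed here; from the visible tail and the Field/Schema imports, each StructField is presumably turned into an Arrow Field carrying its name, nullability and mapped ArrowType before being wrapped in a Schema, roughly (another hedged sketch, reusing the mapping sketch above):

    import scala.collection.JavaConverters._
    import org.apache.arrow.vector.types.pojo.{Field, Schema}
    import org.apache.spark.sql.types.StructType

    def schemaToArrowSchemaSketch(schema: StructType): Schema = {
      val arrowFields = schema.fields.map { field =>
        new Field(field.name, field.nullable,
          sparkTypeToArrowTypeSketch(field.dataType), List.empty[Field].asJava)
      }
      new Schema(arrowFields.toList.asJava)
    }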

/**
* Write an ArrowPayload to a byte array
*/
private[sql] def payloadToByteArray(payload: ArrowPayload, schema: StructType): Array[Byte] = {
val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
val out = new ByteArrayOutputStream()
val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
try {
payload.foreach(writer.writeRecordBatch)
} catch {
case e: Exception =>
throw e
} finally {
writer.close()
payload.foreach(_.close())
}
out.toByteArray
}
}

private[sql] trait ColumnWriter {
@@ -255,7 +303,7 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableVarBinaryVector
= new NullableVarBinaryVector("UTF8StringValue", allocator)
= new NullableVarBinaryVector("BinaryValue", allocator)
override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
@@ -273,6 +321,7 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator)

override protected def setNull(): Unit = valueMutator.setNull(count)
override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+ // TODO: comment on diff btw value representations of date/timestamp
valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
}
}
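For the conversion above: Spark stores a DateType value as a count of days since the Unix epoch, while this writer populates the Arrow vector in milliseconds, hence the 24 * 3600 * 1000 factor. A quick worked check (the concrete date is only an illustration):

    val daysSinceEpoch = 17197                             // 2017-01-31 as stored by Spark
    val millis = daysSinceEpoch.toLong * 24 * 3600 * 1000  // 86,400,000 ms per day
    assert(millis == 1485820800000L)                       // midnight UTC on 2017-01-31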
@@ -286,6 +335,7 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
override protected def setNull(): Unit = valueMutator.setNull(count)

override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+ // TODO: use microsecond timestamp when ARROW-477 is resolved
valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
}
}
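Similarly, Spark's TimestampType values are microseconds since the epoch, so the division by 1000 yields the millisecond values this writer has to emit until ARROW-477 allows microsecond precision; for example:

    val micros = 1485820800123456L       // microseconds since the epoch (illustrative value)
    val millisForArrow = micros / 1000   // 1485820800123L; the trailing 456 microseconds are dropped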
31 changes: 7 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,17 +17,13 @@

package org.apache.spark.sql

- import java.io.{ByteArrayOutputStream, CharArrayWriter}
- import java.nio.channels.Channels
+ import java.io.CharArrayWriter

import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

- import org.apache.arrow.memory.RootAllocator
- import org.apache.arrow.vector.file.ArrowWriter
- import org.apache.arrow.vector.schema.ArrowRecordBatch
import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
@@ -2375,14 +2371,12 @@ class Dataset[T] private[sql](
* @since 2.2.0
*/
@DeveloperApi
- def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = {
- val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue))
+ def collectAsArrow(converter: Option[ArrowConverters] = None): ArrowPayload = {
+ val cnvtr = converter.getOrElse(new ArrowConverters)
withNewExecutionId {
try {
val collectedRows = queryExecution.executedPlan.executeCollect()
- val recordBatch = Arrow.internalRowsToArrowRecordBatch(
- collectedRows, this.schema, allocator)
- recordBatch
+ cnvtr.internalRowsToPayload(collectedRows, this.schema)
} catch {
case e: Exception =>
throw e
@@ -2763,22 +2757,11 @@ class Dataset[T] private[sql](
* Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark.
*/
private[sql] def collectAsArrowToPython(): Int = {
- val recordBatch = collectAsArrow()
- val arrowSchema = Arrow.schemaToArrowSchema(this.schema)
- val out = new ByteArrayOutputStream()
- try {
- val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
- writer.writeRecordBatch(recordBatch)
- writer.close()
- } catch {
- case e: Exception =>
- throw e
- } finally {
- recordBatch.close()
- }
+ val payload = collectAsArrow()
+ val payloadBytes = ArrowConverters.payloadToByteArray(payload, this.schema)

withNewExecutionId {
- PythonRDD.serveIterator(Iterator(out.toByteArray), "serve-Arrow")
+ PythonRDD.serveIterator(Iterator(payloadBytes), "serve-Arrow")
}
}
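A sketch of how the method above might be driven from the JVM side (df is a placeholder DataFrame, the method is private[sql], and PythonRDD.serveIterator appears to return the local port the Python process then connects to):

    val port: Int = df.collectAsArrowToPython()
    // The PySpark side is expected to read the serialized Arrow schema and record
    // batches, as produced by payloadToByteArray, from localhost on this port.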

@@ -19,15 +19,13 @@ package org.apache.spark.sql
import java.io.File
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
- import java.util.{Locale, TimeZone}
+ import java.util.Locale

- import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator

import org.apache.spark.sql.test.SharedSQLContext
- import org.apache.spark.unsafe.types.CalendarInterval


// NOTE - nullable type can be declared as Option[*] or java.lang.*
@@ -38,18 +36,19 @@ private[sql] case class FloatData(i: Int, a_f: Float, b_f: Option[Float])
private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double])


- class ArrowSuite extends SharedSQLContext {
+ class ArrowConvertersSuite extends SharedSQLContext {
import testImplicits._

private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).getFile
}

test("collect to arrow record batch") {
- val arrowRecordBatch = indexData.collectAsArrow()
- assert(arrowRecordBatch.getLength > 0)
- assert(arrowRecordBatch.getNodes.size() > 0)
- arrowRecordBatch.close()
+ val arrowPayload = indexData.collectAsArrow()
+ assert(arrowPayload.nonEmpty)
+ arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getLength > 0))
+ arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getNodes.size() > 0))
+ arrowPayload.foreach(arrowRecordBatch => arrowRecordBatch.close())
}

test("standard type conversion") {
@@ -124,8 +123,9 @@ class ArrowSuite extends SharedSQLContext {
}

test("empty frame collect") {
- val emptyBatch = spark.emptyDataFrame.collectAsArrow()
- assert(emptyBatch.getLength == 0)
+ val arrowPayload = spark.emptyDataFrame.collectAsArrow()
+ assert(arrowPayload.nonEmpty)
+ arrowPayload.foreach(emptyBatch => assert(emptyBatch.getLength == 0))
}

test("unsupported types") {
@@ -163,17 +163,17 @@
private def collectAndValidate(df: DataFrame, arrowFile: String) {
val jsonFilePath = testFile(arrowFile)

- val allocator = new RootAllocator(Integer.MAX_VALUE)
- val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator)
+ val converter = new ArrowConverters
+ val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator)

- val arrowSchema = Arrow.schemaToArrowSchema(df.schema)
+ val arrowSchema = ArrowConverters.schemaToArrowSchema(df.schema)
val jsonSchema = jsonReader.start()
Validator.compareSchemas(arrowSchema, jsonSchema)

- val arrowRecordBatch = df.collectAsArrow(Some(allocator))
- val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
+ val arrowPayload = df.collectAsArrow(Some(converter))
+ val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator)
val vectorLoader = new VectorLoader(arrowRoot)
- vectorLoader.load(arrowRecordBatch)
+ arrowPayload.foreach(vectorLoader.load)
val jsonRoot = jsonReader.read()

Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
