Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-185] Avoid unnecessary copying when simply projecting on fields (#…
Browse files Browse the repository at this point in the history
…187)

* [NSE-185] Avoid unnecessary copying when simply projecting on fields

* Avoid sharing buffers in output
  • Loading branch information
zhztheplayer authored Mar 26, 2021
1 parent 48a181c commit 155afa4
Showing 1 changed file with 207 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,24 @@

package com.intel.oap.expression

import java.util.Collections
import java.util
import java.util.Objects
import java.util.concurrent.TimeUnit

import com.google.common.collect.Lists
import com.intel.oap.expression.ColumnarConditionProjector.{FieldOptimizedProjector, FilterProjector, ProjectorWrapper}
import com.intel.oap.vectorized.ArrowWritableColumnVector

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.gandiva.ipc.GandivaTypes
import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.memory.RootAllocator
Expand All @@ -42,9 +43,13 @@ import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.arrow.memory.ArrowBuf
import org.apache.arrow.util.AutoCloseables
import org.apache.arrow.vector.ValueVector

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.util.control.Breaks._

class ColumnarConditionProjector(
condPrepareList: (TreeNode, ArrowType),
Expand Down Expand Up @@ -115,7 +120,7 @@ class ColumnarConditionProjector(
false
}
val projector = if (skip == false) {
createProjector(projectionArrowSchema, projPrepareList, withCond)
createProjector(projectionArrowSchema, resultArrowSchema, projPrepareList, withCond)
} else {
null
}
Expand All @@ -134,24 +139,25 @@ class ColumnarConditionProjector(
}

def createProjector(
arrowSchema: Schema,
projectionSchema: Schema,
resultSchema: Schema,
prepareList: Seq[(ExpressionTree, ArrowType)],
withCond: Boolean): Projector = synchronized {
withCond: Boolean): ProjectorWrapper = synchronized {
if (projector != null) {
return projector
}
val fieldNodesList = prepareList.map(_._1).toList.asJava
try {
if (withCond) {
Projector.make(arrowSchema, fieldNodesList, SelectionVectorType.SV_INT16)
new FilterProjector(projectionSchema, resultSchema, fieldNodesList, SelectionVectorType.SV_INT16)
} else {
Projector.make(arrowSchema, fieldNodesList)
new FieldOptimizedProjector(projectionSchema, resultSchema, fieldNodesList)
}
} catch {
case e =>
logError(
s"\noriginalInputAttributes is ${originalInputAttributes} ${originalInputAttributes.map(
_.dataType)}, \narrowSchema is ${arrowSchema}, \nProjection is ${prepareList.map(_._1.toProtobuf)}")
_.dataType)}, \nprojectionSchema is ${projectionSchema}, \nresultSchema is ${resultSchema}, \nProjection is ${prepareList.map(_._1.toProtobuf)}")
throw e
}
}
Expand Down Expand Up @@ -258,28 +264,19 @@ class ColumnarConditionProjector(
// For now, we either filter one columnarBatch that has valid rows, or we only need to project;
// in either scenario we need to output one columnarBatch.
beforeEval = System.nanoTime()
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(numRows, resultSchema).toArray
val outputVectors = resultColumnVectors
.map(columnVector => {
columnVector.getValueVector()
})
.toList
.asJava

val cols = projectOrdinalList.map(i => {
columnarBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector()
})
input = ConverterUtils.createArrowRecordBatch(columnarBatch.numRows, cols)
if (conditioner != null) {
projector.evaluate(input, selectionVector, outputVectors);
val outputBatch = if (conditioner != null) {
projector.evaluate(input, numRows, selectionVector);
} else {
projector.evaluate(input, outputVectors);
projector.evaluate(input);
}

ConverterUtils.releaseArrowRecordBatch(input)
val outputBatch =
new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), numRows)

proc_time += ((System.nanoTime() - beforeEval) / (1000 * 1000))
resColumnarBatch = outputBatch
true
Expand Down Expand Up @@ -448,4 +445,190 @@ object ColumnarConditionProjector extends Logging {
numOutputRows,
procTime)
}

/**
 * Common interface for the projector implementations in this object.
 *
 * Both `evaluate` overloads throw `UnsupportedOperationException` by default; an
 * implementation overrides the overload that matches the way it is driven (plain
 * projection, or projection driven by a selection vector).
 */
trait ProjectorWrapper {

  /**
   * Projects `recordBatch` and returns the projected columns as a `ColumnarBatch`.
   *
   * @throws UnsupportedOperationException unless overridden
   */
  def evaluate(recordBatch: ArrowRecordBatch): ColumnarBatch =
    throw new UnsupportedOperationException

  /**
   * Projects `recordBatch` using `selectionVector` and returns a `ColumnarBatch`
   * reporting `numRows` rows.
   *
   * @throws UnsupportedOperationException unless overridden
   */
  def evaluate(
      recordBatch: ArrowRecordBatch,
      numRows: Int,
      selectionVector: SelectionVector): ColumnarBatch =
    throw new UnsupportedOperationException

  /** Releases whatever resources the implementation holds. */
  def close(): Unit
}

/**
 * Proxy projector that is optimized for field projections.
 *
 * Expressions whose root is a plain field reference are served by handing the
 * matching input vector straight to the output (retained, not copied). All other
 * expressions are evaluated through a Gandiva `Projector`.
 *
 * @param projectionSchema schema of the record batches passed to `evaluate`
 * @param resultSchema     schema of the produced batch; field `i` describes the output
 *                         of `exprs.get(i)`
 * @param exprs            projection expressions, in output-column order
 */
class FieldOptimizedProjector(projectionSchema: Schema, resultSchema: Schema,
    exprs: java.util.List[ExpressionTree]) extends ProjectorWrapper {

  /** Plain field expressions paired with their output ordinal (first use of each name only). */
  val fieldExprs = ListBuffer[(ExpressionTree, Int)]()
  val fieldExprNames = new util.HashSet[String]()

  /**
   * Expressions evaluated through Gandiva, paired with their output ordinal.
   *
   * nonFieldExprs may include plain fields that have already appeared earlier in the
   * projection list: such repeats are evaluated (copied) instead of forwarded, to avoid
   * sharing the same buffers between several output columns.
   */
  val nonFieldExprs = ListBuffer[(ExpressionTree, Int)]()

  exprs.asScala.zipWithIndex.foreach {
    case (expr, i) =>
      val root = getRoot(expr)
      // HashSet.add returns false when the name is already present, so a repeated field
      // falls through to nonFieldExprs. The short-circuit keeps getField away from
      // non-field roots.
      if (fieldClazz.isInstance(root) && fieldExprNames.add(getField(root).getName)) {
        fieldExprs.append((expr, i))
      } else {
        nonFieldExprs.append((expr, i))
      }
  }

  val fieldResultSchema = new Schema(
    fieldExprs.map {
      case (_, i) =>
        resultSchema.getFields.get(i)
    }.asJava)

  val nonFieldResultSchema = new Schema(
    nonFieldExprs.map {
      case (_, i) =>
        resultSchema.getFields.get(i)
    }.asJava)

  val nonFieldProjector: Option[Projector] =
    if (nonFieldExprs.isEmpty) {
      None
    } else {
      Some(
        Projector.make(
          projectionSchema, nonFieldExprs.map {
            case (e, _) => e
          }.toList.asJava))
    }

  /**
   * (field name, output ordinal, input ordinal) for every forwarded field. The input
   * ordinal is the first position in `projectionSchema` with that name, or -1 when there
   * is none. Resolved once here because it does not change between batches; a missing
   * field is still only reported from `evaluate`.
   */
  private val fieldBindings: List[(String, Int, Int)] = fieldExprs.toList.map {
    case (expr, outputOrdinal) =>
      val name = getField(getRoot(expr)).getName
      val inputOrdinal = projectionSchema.getFields.asScala
        .indexWhere(candidate => Objects.equals(name, candidate.getName))
      (name, outputOrdinal, inputOrdinal)
  }

  /**
   * Projects one record batch.
   *
   * @param recordBatch input batch laid out according to `projectionSchema`
   * @return a batch with one column per expression and `recordBatch.getLength` rows
   * @throws IllegalArgumentException if a forwarded field is not part of `projectionSchema`
   * @throws IllegalStateException    if an output slot is filled twice or left empty
   */
  override def evaluate(recordBatch: ArrowRecordBatch): ColumnarBatch = {
    val numRows = recordBatch.getLength
    val projectedAVs = new Array[ArrowWritableColumnVector](exprs.size())

    // Output vectors for the expression-based projections. They are registered in their
    // output slots right away so the failure path below can release them.
    val nonFieldResultColumnVectors =
      ArrowWritableColumnVector.allocateColumns(numRows,
        ArrowUtils.fromArrowSchema(nonFieldResultSchema))
    nonFieldExprs.zipWithIndex.foreach {
      case ((_, i), k) => projectedAVs(i) = nonFieldResultColumnVectors(k)
    }

    var succeeded = false
    try {
      // Execute expression-based projections
      val outputVectors = nonFieldResultColumnVectors.map(_.getValueVector).toList.asJava
      nonFieldProjector.foreach {
        _.evaluate(recordBatch, outputVectors)
      }

      // Forward plain fields: retain the input vector and place it in the output slot.
      val inAVs = ArrowWritableColumnVector.loadColumns(numRows, projectionSchema, recordBatch)
      try {
        fieldBindings.foreach {
          case (name, outputOrdinal, inputOrdinal) =>
            if (inputOrdinal < 0) {
              throw new IllegalArgumentException("Field not found for projection: " + name)
            }
            if (projectedAVs(outputOrdinal) != null) {
              throw new IllegalStateException(
                "Output column " + outputOrdinal + " is already projected")
            }
            val vector = inAVs(inputOrdinal)
            vector.retain()
            projectedAVs(outputOrdinal) = vector
        }
      } finally {
        // Drop our references to the loaded inputs, also when a check above failed.
        // Forwarded vectors stay alive through the retain() taken above.
        inAVs.foreach(_.close())
      }

      // Projected vector count check
      projectedAVs.zipWithIndex.foreach {
        case (arrowVector, ordinal) =>
          if (arrowVector == null) {
            throw new IllegalStateException("Output column " + ordinal + " was not projected")
          }
      }

      val outputBatch =
        new ColumnarBatch(projectedAVs.map(_.asInstanceOf[ColumnVector]), numRows)
      succeeded = true
      outputBatch
    } finally {
      if (!succeeded) {
        // Nothing was handed to the caller: release the allocated result vectors and the
        // references taken on forwarded inputs.
        // NOTE(review): assumes close() gives back exactly one reference per call, which
        // is what the retain()/close() pairing above already relies on.
        projectedAVs.foreach(vector => if (vector != null) vector.close())
      }
    }
  }

  /** Closes the Gandiva projector, if one was created. */
  override def close(): Unit = {
    nonFieldProjector.foreach(_.close())
  }
}

/**
 * Projector used together with a filter: every expression is evaluated by a Gandiva
 * `Projector` that reads rows through a selection vector.
 *
 * @param projectionSchema    schema of the record batches passed to `evaluate`
 * @param resultSchema        schema used to allocate the output columns
 * @param exprs               projection expressions
 * @param selectionVectorType selection vector type the Gandiva projector is built for
 */
class FilterProjector(projectionSchema: Schema, resultSchema: Schema,
    exprs: java.util.List[ExpressionTree],
    selectionVectorType: GandivaTypes.SelectionVectorType) extends ProjectorWrapper {
  val projector: Projector = Projector.make(projectionSchema, exprs, selectionVectorType)

  /**
   * Evaluates all expressions over `recordBatch` using `selectionVector`.
   *
   * @param recordBatch     input batch laid out according to `projectionSchema`
   * @param numRows         capacity of the allocated output columns and row count
   *                        reported by the returned batch
   * @param selectionVector selection vector passed to the Gandiva projector
   * @return a batch holding the freshly allocated result columns
   */
  override def evaluate(recordBatch: ArrowRecordBatch, numRows: Int,
      selectionVector: SelectionVector): ColumnarBatch = {
    val resultColumnVectors =
      ArrowWritableColumnVector.allocateColumns(numRows, ArrowUtils.fromArrowSchema(resultSchema))

    var succeeded = false
    try {
      val outputVectors = resultColumnVectors.map(_.getValueVector).toList.asJava

      projector.evaluate(recordBatch, selectionVector, outputVectors)

      val outputBatch =
        new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), numRows)
      succeeded = true
      outputBatch
    } finally {
      if (!succeeded) {
        // The result columns were never handed to the caller; free them so a failed
        // evaluation does not leak the buffers allocated above.
        resultColumnVectors.foreach(_.close())
      }
    }
  }

  /** Closes the underlying Gandiva projector. */
  override def close(): Unit = {
    projector.close()
  }
}

// Reflection handles used to look inside Gandiva expression objects: the `root`
// field declared on ExpressionTree and the `field` field declared on FieldNode.
// Both are made accessible up front (presumably they are not public — confirm
// against the Gandiva version in use).
val treeClazz: Class[ExpressionTree] = classOf[ExpressionTree]
val rootField: java.lang.reflect.Field = {
  val declared = treeClazz.getDeclaredField("root")
  declared.setAccessible(true)
  declared
}
val fieldClazz: Class[_] = Class.forName("org.apache.arrow.gandiva.expression.FieldNode")
val fieldField: java.lang.reflect.Field = {
  val declared = fieldClazz.getDeclaredField("field")
  declared.setAccessible(true)
  declared
}

/**
 * Reads the `root` node of an expression tree through reflection.
 *
 * @param expressionTree tree to inspect
 * @return the value of the tree's `root` field, cast to `TreeNode`
 */
def getRoot(expressionTree: ExpressionTree): TreeNode = {
  val rootNode = rootField.get(expressionTree)
  rootNode.asInstanceOf[TreeNode]
}

/**
 * Reads the Arrow `Field` held by a Gandiva field node through reflection.
 *
 * @param fieldNode object expected to be an instance of `fieldClazz`
 * @return the value of the node's `field` field, cast to `Field`
 * @throws IllegalArgumentException if `fieldNode` is not an instance of `fieldClazz`
 */
def getField(fieldNode: Any): Field =
  if (fieldClazz.isInstance(fieldNode)) {
    fieldField.get(fieldNode).asInstanceOf[Field]
  } else {
    throw new IllegalArgumentException
  }
}

0 comments on commit 155afa4

Please sign in to comment.