
Commit

chore: Revise batch pull approach to more follow C Data interface semantics (#893)

* chore: Revise batch pull approach to more follow C Data interface semantics

* fix clippy

* Remove ExportedBatch
viirya authored Sep 3, 2024
1 parent 033fe6f commit 046a62c
Showing 6 changed files with 98 additions and 155 deletions.
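
At a high level, the batch-pull contract changes as in the following sketch, a paraphrase in Scala of the Java/Scala signatures shown in the diffs below; the trait names are illustrative only, the real class is org.apache.comet.CometBatchIterator:

// Before: the JVM allocated the ArrowArray/ArrowSchema structs itself and
// returned their addresses (prefixed by the row count) as a long[], which it
// later had to deallocate again.
trait OldBatchIterator {
  /** Returns Array(numRows, arrayAddr0, schemaAddr0, ...), or Array(-1) at end of input. */
  def next(): Array[Long]
}

// After: the native consumer pre-allocates the structs and passes their
// addresses in; the JVM exports each column directly into them, ownership
// follows the C Data interface (the importer runs the release callbacks),
// and only the row count is returned.
trait NewBatchIterator {
  /** Exports one batch into the caller-allocated structs and returns its row
   *  count, or -1 when the input iterator is exhausted. */
  def next(arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Int
}
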
44 changes: 0 additions & 44 deletions common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala

This file was deleted.

69 changes: 29 additions & 40 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -47,50 +47,39 @@ class NativeUtil {
* an exported batches object containing an array containing number of rows + pairs of memory
* addresses in the format of (address of Arrow array, address of Arrow schema)
*/
def exportBatch(batch: ColumnarBatch): ExportedBatch = {
val exportedVectors = mutable.ArrayBuffer.empty[Long]
exportedVectors += batch.numRows()

// Run checks prior to exporting the batch
(0 until batch.numCols()).foreach { index =>
val c = batch.column(index)
if (!c.isInstanceOf[CometVector]) {
batch.close()
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}

val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]

def exportBatch(
arrayAddrs: Array[Long],
schemaAddrs: Array[Long],
batch: ColumnarBatch): Int = {
(0 until batch.numCols()).foreach { index =>
val cometVector = batch.column(index).asInstanceOf[CometVector]
val valueVector = cometVector.getValueVector

val provider = if (valueVector.getField.getDictionary != null) {
cometVector.getDictionaryProvider
} else {
null
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector

val provider = if (valueVector.getField.getDictionary != null) {
a.getDictionaryProvider
} else {
null
}

// The array and schema structures are allocated by native side.
// Don't need to deallocate them here.
val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
val arrowArray = ArrowArray.wrap(arrayAddrs(index))
Data.exportVector(
allocator,
getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}

val arrowSchema = ArrowSchema.allocateNew(allocator)
val arrowArray = ArrowArray.allocateNew(allocator)
arrowSchemas += arrowSchema
arrowArrays += arrowArray
Data.exportVector(
allocator,
getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)

exportedVectors += arrowArray.memoryAddress()
exportedVectors += arrowSchema.memoryAddress()
}

ExportedBatch(exportedVectors.toArray, arrowSchemas.toArray, arrowArrays.toArray)
batch.numRows()
}

/**
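For context, a minimal self-contained sketch (not Comet code) of the export half that the new exportBatch(arrayAddrs, schemaAddrs, batch) performs per column: wrap the caller-provided struct addresses and export the vector into them via the Arrow Java C Data API. The allocator, the IntVector, and the final importVector call (standing in for the native consumer) are illustrative assumptions.

import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.IntVector

object ExportIntoCallerStructs {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()

    // One column's worth of data, standing in for a CometVector's value vector.
    val vector = new IntVector("col", allocator)
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.setSafe(i, i * 10))
    vector.setValueCount(3)

    // In Comet these structs live on the native side; allocating them here is
    // only a way to obtain two addresses for the sketch.
    val arrayStruct = ArrowArray.allocateNew(allocator)
    val schemaStruct = ArrowSchema.allocateNew(allocator)
    val arrayAddr: Long = arrayStruct.memoryAddress()
    val schemaAddr: Long = schemaStruct.memoryAddress()

    // What exportBatch does per column: wrap the raw addresses and export into
    // them; no JVM-side deallocation of the structs is needed afterwards.
    Data.exportVector(
      allocator,
      vector,
      null, // dictionary provider, only needed for dictionary-encoded vectors
      ArrowArray.wrap(arrayAddr),
      ArrowSchema.wrap(schemaAddr))

    // The consumer (native in Comet) imports the structs, which transfers
    // ownership and invokes the release callbacks when done.
    val imported = Data.importVector(allocator, arrayStruct, schemaStruct, null)
    println(imported.getValueCount) // 3
    imported.close()
    vector.close()
  }
}
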
100 changes: 58 additions & 42 deletions native/core/src/execution/operators/scan.rs
@@ -15,21 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use futures::Stream;
use itertools::Itertools;
use std::rc::Rc;
use std::{
any::Any,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use futures::Stream;
use itertools::Itertools;

use arrow::compute::{cast_with_options, CastOptions};
use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_data::ArrayData;
use arrow_schema::{DataType, Field, Schema, SchemaRef};

use crate::{
errors::CometError,
execution::{
@@ -38,17 +33,22 @@ use crate::{
},
jvm_bridge::{jni_call, JVMClasses},
};
use arrow::compute::{cast_with_options, CastOptions};
use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_data::ffi::FFI_ArrowArray;
use arrow_data::ArrayData;
use arrow_schema::ffi::FFI_ArrowSchema;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::{
execution::TaskContext,
physical_expr::*,
physical_plan::{ExecutionPlan, *},
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult};
use jni::{
objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
sys::jlongArray,
};
use jni::objects::JValueGen;
use jni::objects::{GlobalRef, JObject};
use jni::sys::jsize;

/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file
/// scan or the result of reading a broadcast or shuffle exchange.
@@ -86,7 +86,7 @@ impl ScanExec {
// may end up either unpacking dictionary arrays or dictionary-encoding arrays.
// Dictionary-encoded primitive arrays are always unpacked.
let first_batch = if let Some(input_source) = input_source.as_ref() {
ScanExec::get_next(exec_context_id, input_source.as_obj())?
ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?
} else {
InputBatch::EOF
};
Expand Down Expand Up @@ -153,6 +153,7 @@ impl ScanExec {
let next_batch = ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
self.data_types.len(),
)?;
*current_batch = Some(next_batch);
}
@@ -161,7 +162,11 @@
}

/// Invokes JNI call to get next batch.
fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch, CometError> {
fn get_next(
exec_context_id: i64,
iter: &JObject,
num_cols: usize,
) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
// This is a unit test. We don't need to call JNI.
return Ok(InputBatch::EOF);
@@ -175,49 +180,60 @@ impl ScanExec {
}

let mut env = JVMClasses::get_env()?;
let batch_object: JObject = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next() -> JObject)?
};

if batch_object.is_null() {
return Err(CometError::from(ExecutionError::GeneralError(format!(
"Null batch object. Plan id: {}",
exec_context_id
))));
let mut array_addrs = Vec::with_capacity(num_cols);
let mut schema_addrs = Vec::with_capacity(num_cols);

for _ in 0..num_cols {
let arrow_array = Rc::new(FFI_ArrowArray::empty());
let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
let (array_ptr, schema_ptr) = (
Rc::into_raw(arrow_array) as i64,
Rc::into_raw(arrow_schema) as i64,
);

array_addrs.push(array_ptr);
schema_addrs.push(schema_ptr);
}

let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() as jlongArray) };
// Prepare the java array parameters
let long_array_addrs = env.new_long_array(num_cols as jsize)?;
let long_schema_addrs = env.new_long_array(num_cols as jsize)?;

let addresses = unsafe { env.get_array_elements(&batch_object, ReleaseMode::NoCopyBack)? };
env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;

// First element is the number of rows.
let num_rows = unsafe { *addresses.as_ptr() as i64 };
let array_obj = JObject::from(long_array_addrs);
let schema_obj = JObject::from(long_schema_addrs);

if num_rows < 0 {
return Ok(InputBatch::EOF);
}
let array_obj = JValueGen::Object(array_obj.as_ref());
let schema_obj = JValueGen::Object(schema_obj.as_ref());

let num_rows: i32 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
};

let array_num = addresses.len() - 1;
if array_num % 2 != 0 {
return Err(CometError::Internal(format!(
"Invalid number of Arrow Array addresses: {}",
array_num
)));
if num_rows == -1 {
return Ok(InputBatch::EOF);
}

let num_arrays = array_num / 2;
let array_elements = unsafe { addresses.as_ptr().add(1) };
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_arrays);
let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);

for i in 0..num_arrays {
let array_ptr = unsafe { *(array_elements.add(i * 2)) };
let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
for i in 0..num_cols {
let array_ptr = array_addrs[i];
let schema_ptr = schema_addrs[i];
let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;

// TODO: validate array input data

inputs.push(make_array(array_data));

// Drop the Arcs to avoid memory leak
unsafe {
Rc::from_raw(array_ptr as *const FFI_ArrowArray);
Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
}
}

Ok(InputBatch::new(inputs, Some(num_rows as usize)))
5 changes: 3 additions & 2 deletions native/core/src/jvm_bridge/batch_iterator.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use jni::signature::Primitive;
use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
@@ -37,8 +38,8 @@ impl<'a> CometBatchIterator<'a> {

Ok(CometBatchIterator {
class,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J")?,
method_next_ret: ReturnType::Array,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?,
method_next_ret: ReturnType::Primitive(Primitive::Int),
})
}
}
33 changes: 8 additions & 25 deletions spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -23,7 +23,6 @@

import org.apache.spark.sql.vectorized.ColumnarBatch;

import org.apache.comet.vector.ExportedBatch;
import org.apache.comet.vector.NativeUtil;

/**
@@ -35,41 +34,25 @@ public class CometBatchIterator {
final Iterator<ColumnarBatch> input;
final NativeUtil nativeUtil;

private ExportedBatch lastBatch;

CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
this.input = input;
this.nativeUtil = nativeUtil;
this.lastBatch = null;
}

/**
* Get the next batches of Arrow arrays. It will consume input iterator and return Arrow arrays by
* addresses. If the input iterator is done, it will return a one negative element array
* indicating the end of the iterator.
* Get the next batches of Arrow arrays.
*
* @param arrayAddrs The addresses of the ArrowArray structures.
* @param schemaAddrs The addresses of the ArrowSchema structures.
* @return the number of rows of the current batch. -1 if there is no more batch.
*/
public long[] next() {
// Native side already copied the content of ArrowSchema and ArrowArray. We should deallocate
// the ArrowSchema and ArrowArray base structures allocated in JVM.
if (lastBatch != null) {
lastBatch.close();
lastBatch = null;
}

public int next(long[] arrayAddrs, long[] schemaAddrs) {
boolean hasBatch = input.hasNext();

if (!hasBatch) {
return new long[] {-1};
return -1;
}

lastBatch = nativeUtil.exportBatch(input.next());
return lastBatch.batch();
}

public void close() {
if (lastBatch != null) {
lastBatch.close();
lastBatch = null;
}
return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
}
}
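
A hypothetical JVM-side driver for the new contract might look like the sketch below. In Comet the real driver is the native ScanExec shown above, which allocates the structs with FFI_ArrowArray::empty()/FFI_ArrowSchema::empty(); here the Arrow Java equivalents are used, and the iterator instance, column count, and importVector step are assumptions for illustration (struct cleanup is elided).

import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data}
import org.apache.arrow.memory.BufferAllocator
import org.apache.comet.CometBatchIterator

object DrainSketch {
  // Drains a CometBatchIterator: for each batch, allocate one ArrowArray /
  // ArrowSchema pair per column, pass their addresses to next(), and treat a
  // -1 return value as end of input.
  def drain(iterator: CometBatchIterator, numCols: Int, allocator: BufferAllocator): Unit = {
    var numRows = 0
    while (numRows != -1) {
      val arrayStructs = Array.fill(numCols)(ArrowArray.allocateNew(allocator))
      val schemaStructs = Array.fill(numCols)(ArrowSchema.allocateNew(allocator))

      numRows = iterator.next(
        arrayStructs.map(_.memoryAddress()),
        schemaStructs.map(_.memoryAddress()))

      if (numRows != -1) {
        (0 until numCols).foreach { i =>
          // Importing takes ownership of the exported buffers and runs the
          // release callbacks, per the C Data Interface.
          val column = Data.importVector(allocator, arrayStructs(i), schemaStructs(i), null)
          // ... use the column for numRows rows, then release it ...
          column.close()
        }
      }
    }
  }
}
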
2 changes: 0 additions & 2 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -159,8 +159,6 @@ class CometExecIterator(
}
nativeLib.releasePlan(plan)

cometBatchIterators.foreach(_.close())

// The allocator thinks the exported ArrowArray and ArrowSchema structs are not released,
// so it will report:
// Caused by: java.lang.IllegalStateException: Memory was leaked by query.
