This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-732] Support Struct complex type in Shuffle (#733)
* [NSE-732] Support Struct and Map nested types in Shuffle

* Format C++ code

* Turn on Map and Struct

* Fix Typo

* Troubleshoot recordbatch building

* Fix Clang style

* Preserve the previous check approach

* Fix Clang style

* Add check for nested complex types
zhixingheyi-tian authored Mar 1, 2022
1 parent 0cbc442 commit 7f59755
Showing 9 changed files with 394 additions and 37 deletions.
@@ -1930,8 +1930,21 @@ final void setNull(int rowId) {
}

private static class StructWriter extends ArrowVectorWriter {
private final StructVector writer;

StructWriter(StructVector vector, ArrowVectorWriter[] children) {
super(vector);
this.writer = vector;
}

@Override
void setNull(int rowId) {
writer.setNull(rowId);
}

@Override
void setNotNull(int rowId) {
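// setIndexDefined sets the slot's validity bit, so the struct entry reads as non-null.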
writer.setIndexDefined(rowId);
}
}

@@ -37,7 +37,6 @@ import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{ArrowFieldNode, ArrowRecordBatch, IpcOption, MessageChannelReader, MessageResult, MessageSerializer}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -48,13 +47,13 @@ import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import io.netty.buffer.{ByteBuf, ByteBufAllocator, ByteBufOutputStream}
import java.nio.channels.{Channels, WritableByteChannel}

import com.google.common.collect.Lists
import java.io.{InputStream, OutputStream}
import java.util
import java.util.concurrent.TimeUnit.SECONDS

import org.apache.arrow.vector.types.TimeUnit
@@ -517,6 +516,27 @@ object ConverterUtils extends Logging {
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def checkIfNestTypeSupported(dt: DataType): Unit = dt match {
case d: ArrayType => checkIfTypeSupported(d.elementType)
case d: StructType =>
for ( structField <- d.fields ) {
checkIfTypeSupported(structField.dataType)
}
case d: BooleanType =>
case d: ByteType =>
case d: ShortType =>
case d: IntegerType =>
case d: LongType =>
case d: FloatType =>
case d: DoubleType =>
case d: StringType =>
case d: DateType =>
case d: DecimalType =>
case d: TimestampType =>
case _ =>
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def createArrowField(name: String, dt: DataType): Field = dt match {
case at: ArrayType =>
new Field(
@@ -526,7 +546,15 @@
case mt: MapType =>
throw new UnsupportedOperationException(s"${dt} is not supported yet")
case st: StructType =>
throw new UnsupportedOperationException(s"${dt} is not supported yet")
val fieldlist = new util.ArrayList[Field]
for (structField <- st.fields) {
fieldlist.add(createArrowField(structField.name, structField.dataType))
}
new Field(
name,
FieldType.nullable(ArrowType.Struct.INSTANCE),
fieldlist)
case _ =>
Field.nullable(name, CodeGeneration.getResultType(dt))
}
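For context, a minimal sketch of how the new ConverterUtils helpers are expected to behave. This is not part of the diff; it assumes ConverterUtils is in scope and that the pre-existing checkIfTypeSupported accepts primitive field types while rejecting complex ones:

import org.apache.spark.sql.types._

val flat = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
// Passes: every struct field is a primitive type.
ConverterUtils.checkIfNestTypeSupported(flat)
// Builds an Arrow Struct field with one child Field per Spark field.
val arrowField = ConverterUtils.createArrowField("struct_field", flat)
// Under the stated assumption, a second level of nesting is still rejected:
// ConverterUtils.checkIfNestTypeSupported(ArrayType(flat))   // throws UnsupportedOperationException
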
@@ -86,7 +86,7 @@ case class ColumnarShuffleExchangeExec(
// check input datatype
for (attr <- child.output) {
try {
ConverterUtils.createArrowField(attr)
ConverterUtils.checkIfNestTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.misc

import java.nio.file.Files

import com.intel.oap.tpc.util.TPCRunner
import org.apache.log4j.{Level, LogManager}
import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.test.SharedSparkSession

class ComplexTypeSuite extends QueryTest with SharedSparkSession {

private val MAX_DIRECT_MEMORY = "5000m"
private var runner: TPCRunner = _

private var lPath: String = _
private var rPath: String = _
private val scale = 100

override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
conf.set("spark.memory.offHeap.size", String.valueOf(MAX_DIRECT_MEMORY))
.set("spark.plugins", "com.intel.oap.GazellePlugin")
.set("spark.sql.codegen.wholeStage", "true")
.set("spark.sql.sources.useV1SourceList", "")
.set("spark.oap.sql.columnar.tmp_dir", "/tmp/")
.set("spark.sql.columnar.sort.broadcastJoin", "true")
.set("spark.storage.blockManagerSlaveTimeoutMs", "3600000")
.set("spark.executor.heartbeatInterval", "3600000")
.set("spark.network.timeout", "3601s")
.set("spark.oap.sql.columnar.preferColumnar", "true")
.set("spark.oap.sql.columnar.sortmergejoin", "true")
.set("spark.sql.columnar.codegen.hashAggregate", "false")
.set("spark.sql.columnar.sort", "true")
.set("spark.sql.columnar.window", "true")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.unsafe.exceptionOnMemoryLeak", "false")
.set("spark.network.io.preferDirectBufs", "false")
.set("spark.sql.sources.useV1SourceList", "arrow,parquet")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")
.set("spark.oap.sql.columnar.autorelease", "false")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.shuffle.partitions", "50")
.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "5")
.set("spark.oap.sql.columnar.shuffledhashjoin.buildsizelimit", "200m")
.set("spark.oap.sql.columnar.rowtocolumnar", "false")
.set("spark.oap.sql.columnar.columnartorow", "false")
return conf
}

override def beforeAll(): Unit = {
super.beforeAll()
LogManager.getRootLogger.setLevel(Level.WARN)

val lfile = Files.createTempFile("", ".parquet").toFile
lfile.deleteOnExit()
lPath = lfile.getAbsolutePath
spark.range(2).select(col("id"), expr("1").as("kind"),
expr("array(1, 2)").as("arr_field"),
expr("struct(1, 2)").as("struct_field"))
.write
.format("parquet")
.mode("overwrite")
.parquet(lPath)

val rfile = Files.createTempFile("", ".parquet").toFile
rfile.deleteOnExit()
rPath = rfile.getAbsolutePath
spark.range(2).select(col("id"), expr("id % 2").as("kind"),
expr("array(1, 2)").as("arr_field"),
expr("struct(1, 2)").as("struct_field"))
.write
.format("parquet")
.mode("overwrite")
.parquet(rPath)

spark.catalog.createTable("ltab", lPath, "arrow")
spark.catalog.createTable("rtab", rPath, "arrow")
}

test("Test Array in Shuffle split") {
val df = spark.sql("SELECT ltab.arr_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.count() == 2)
}

test("Test Struct in Shuffle stage") {
val df = spark.sql("SELECT ltab.struct_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.count() == 2)
}

override def afterAll(): Unit = {
super.afterAll()
}
}
31 changes: 30 additions & 1 deletion native-sql-engine/cpp/src/jni/jni_common.h
@@ -150,6 +150,13 @@ arrow::Status AppendBuffers(std::shared_ptr<arrow::Array> column,
(*buffers).push_back(list_array->value_offsets());
RETURN_NOT_OK(AppendBuffers(list_array->values(), buffers));
} break;
case arrow::Type::STRUCT: {
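// For a struct column, ship the validity bitmap first, then append each child's buffers depth-first.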
auto struct_array = std::dynamic_pointer_cast<arrow::StructArray>(column);
(*buffers).push_back(struct_array->null_bitmap());
for (int i = 0; i < struct_array->num_fields(); ++i) {
RETURN_NOT_OK(AppendBuffers(struct_array->field(i), buffers));
}
} break;
default: {
for (auto& buffer : column->data()->buffers) {
(*buffers).push_back(buffer);
@@ -198,7 +205,29 @@ arrow::Status MakeArrayData(std::shared_ptr<arrow::DataType> type, int num_rows,
auto list_array =
arrow::ListArray::FromArrays(*offset_array, *child_array).ValueOrDie();
*arr_data = list_array->data();

} break;
case arrow::Type::STRUCT: {
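// Inverse of AppendBuffers above: consume one buffer as the struct's validity bitmap,
// then recursively rebuild each child's ArrayData from the remaining buffers.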
int64_t null_count = arrow::kUnknownNullCount;
std::vector<std::shared_ptr<arrow::Buffer>> buffers;
if (*buf_idx_ptr >= in_bufs_len) {
return arrow::Status::Invalid("insufficient number of in_buf_addrs");
}
if (in_bufs[*buf_idx_ptr]->size() == 0) {
null_count = 0;
}
buffers.push_back(in_bufs[*buf_idx_ptr]);
*buf_idx_ptr += 1;

ArrayDataVector struct_child_data_vec;
for (int i = 0; i < type->num_fields(); ++i) {
std::shared_ptr<arrow::Field> field = type->field(i);
std::shared_ptr<arrow::ArrayData> struct_child_data;
RETURN_NOT_OK(MakeArrayData(field->type(), -1, in_bufs, in_bufs_len,
&struct_child_data, buf_idx_ptr));
struct_child_data_vec.push_back(struct_child_data);
}
*arr_data = arrow::ArrayData::Make(type, num_rows, std::move(buffers),
struct_child_data_vec, null_count);
} break;
default:
return arrow::Status::NotImplemented("MakeArrayData for type ", type->ToString(),
43 changes: 12 additions & 31 deletions native-sql-engine/cpp/src/shuffle/splitter.cc
@@ -292,10 +292,10 @@ arrow::Status Splitter::Init() {
case arrow::LargeStringType::type_id:
large_binary_array_idx_.push_back(i);
break;
case arrow::ListType::type_id:
list_array_idx_.push_back(i);
break;
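// Struct, Map, and (large) List columns now share the builder-based split path, tracked via list_array_idx_.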
case arrow::StructType::type_id:
case arrow::MapType::type_id:
case arrow::LargeListType::type_id:
case arrow::ListType::type_id:
list_array_idx_.push_back(i);
break;
case arrow::NullType::type_id:
@@ -511,6 +511,9 @@ arrow::Status Splitter::CacheRecordBatch(int32_t partition_id, bool reset_buffer
large_binary_idx++;
break;
}
case arrow::StructType::type_id:
case arrow::MapType::type_id:
case arrow::LargeListType::type_id:
case arrow::ListType::type_id: {
auto& builder = partition_list_builders_[list_idx][partition_id];
if (reset_buffers) {
@@ -524,19 +527,6 @@
list_idx++;
break;
}
case arrow::LargeListType::type_id: {
auto& builder = partition_list_builders_[list_idx][partition_id];
if (reset_buffers) {
RETURN_NOT_OK(builder->Finish(&arrays[i]));
builder->Reset();
} else {
RETURN_NOT_OK(builder->Finish(&arrays[i]));
builder->Reset();
RETURN_NOT_OK(builder->Reserve(num_rows));
}
list_idx++;
break;
}
case arrow::NullType::type_id: {
arrays[i] = arrow::MakeArray(arrow::ArrayData::Make(
arrow::null(), num_rows, {nullptr, nullptr}, num_rows));
@@ -618,6 +608,9 @@ arrow::Status Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t n
large_binary_idx++;
break;
}
case arrow::StructType::type_id:
case arrow::MapType::type_id:
case arrow::LargeListType::type_id:
case arrow::ListType::type_id: {
std::unique_ptr<arrow::ArrayBuilder> array_builder;
RETURN_NOT_OK(
@@ -628,16 +621,6 @@
list_idx++;
break;
}
case arrow::LargeListType::type_id: {
std::unique_ptr<arrow::ArrayBuilder> array_builder;
RETURN_NOT_OK(
MakeBuilder(options_.memory_pool, column_type_id_[i], &array_builder));
assert(array_builder != nullptr);
RETURN_NOT_OK(array_builder->Reserve(new_size));
new_list_builders.push_back(std::move(array_builder));
list_idx++;
break;
}
case arrow::NullType::type_id:
break;
default: {
@@ -687,12 +670,10 @@ arrow::Status Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t n
std::move(new_large_binary_builders[large_binary_idx]);
large_binary_idx++;
break;
case arrow::ListType::type_id:
partition_list_builders_[list_idx][partition_id] =
std::move(new_list_builders[list_idx]);
list_idx++;
break;
case arrow::StructType::type_id:
case arrow::MapType::type_id:
case arrow::LargeListType::type_id:
case arrow::ListType::type_id:
partition_list_builders_[list_idx][partition_id] =
std::move(new_list_builders[list_idx]);
list_idx++;
2 changes: 2 additions & 0 deletions native-sql-engine/cpp/src/shuffle/utils.h
@@ -140,6 +140,8 @@ static arrow::Result<std::vector<std::shared_ptr<arrow::DataType>>> ToSplitterTy
case arrow::StringType::type_id:
case arrow::LargeBinaryType::type_id:
case arrow::LargeStringType::type_id:
case arrow::StructType::type_id:
case arrow::MapType::type_id:
case arrow::ListType::type_id:
case arrow::LargeListType::type_id:
case arrow::Decimal128Type::type_id: