This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-732] Support Map complex type in Shuffle (#749)
* [NSE-732] Support Struct and Map nested types in Shuffle

* Format C code

* Turn on Map and Struct

* Fix Typo

* Troubleshoot RecordBatch building

* Fix Clang style

* Preserve the previous check approach

* Fix Clang style

* Add check for nested complex types

* First draft commit

* Add unit tests

* Clean code

* Format code

* Improve ArrowRowToColumnarExec buildCheck

* Add fallback for complex types in partitioning keys
zhixingheyi-tian authored Mar 8, 2022
1 parent 8a4a741 commit 20379cd
Showing 7 changed files with 339 additions and 32 deletions.
@@ -287,7 +287,11 @@ private ArrowVectorWriter createVectorWriter(ValueVector vector) {
return new TimestampMicroWriter((TimeStampVector) vector);
} else if (vector instanceof MapVector) {
MapVector mapVector = (MapVector) vector;
return new MapWriter(mapVector);
final StructVector structVector = (StructVector) mapVector.getDataVector();
final FieldVector keyChild = structVector.getChild(MapVector.KEY_NAME);
final FieldVector valueChild = structVector.getChild(MapVector.VALUE_NAME);
return new MapWriter(mapVector, createVectorWriter(keyChild),
createVectorWriter(valueChild));
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
ArrowVectorWriter elementVector = createVectorWriter(listVector.getDataVector());
@@ -1949,8 +1953,36 @@ void setNotNull(int rowId) {
}

private static class MapWriter extends ArrowVectorWriter {
MapWriter(ValueVector vector) {
private final MapVector mapVector;
private final ArrowVectorWriter keyWriter;
private final ArrowVectorWriter valueWriter;

MapWriter(MapVector vector, ArrowVectorWriter keyWriter, ArrowVectorWriter valueWriter) {
super(vector);
this.mapVector = vector;
this.keyWriter = keyWriter;
this.valueWriter = valueWriter;
}

public ArrowVectorWriter getKeyWriter() {
return keyWriter;
}

public ArrowVectorWriter getValueWriter() {
return valueWriter;
}

@Override
void setArray(int rowId, int offset, int length) {
int index = rowId * ListVector.OFFSET_WIDTH;
mapVector.getOffsetBuffer().setInt(index, offset);
mapVector.getOffsetBuffer().setInt(index + ListVector.OFFSET_WIDTH, offset + length);
mapVector.setNotNull(rowId);
}

@Override
final void setNull(int rowId) {
mapVector.setNull(rowId);
}
}
}
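
A minimal sketch (not part of this commit) of the Arrow layout the new MapWriter relies on: a MapVector's data vector is a non-nullable struct of entries whose children are named by MapVector.KEY_NAME and MapVector.VALUE_NAME, which is why createVectorWriter above can build the key and value writers by looking up those two children. The helper below is hypothetical; the vector classes and constants are from the Arrow Java API.

import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.complex.{MapVector, StructVector}

object MapLayoutSketch {
  // Return the key and value child vectors of a map vector's entries struct.
  def keyAndValueChildren(mapVector: MapVector): (FieldVector, FieldVector) = {
    val entries = mapVector.getDataVector.asInstanceOf[StructVector]
    (entries.getChild(MapVector.KEY_NAME), entries.getChild(MapVector.VALUE_NAME))
  }
}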
@@ -63,7 +63,7 @@ class ArrowRowToColumnarExec(child: SparkPlan) extends RowToColumnarExec(child =
case d: DecimalType =>
case d: TimestampType =>
case d: BinaryType =>
case d: ArrayType =>
case d: ArrayType => ConverterUtils.checkIfTypeSupported(d.elementType)
case _ =>
throw new UnsupportedOperationException(s"${field.dataType} " +
s"is not supported in ArrowRowToColumnarExec.")
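
A short, hypothetical usage sketch (not in the patch) of the one-line change above: an ArrayType column now passes ArrowRowToColumnarExec's build check only when its element type itself passes ConverterUtils.checkIfTypeSupported.

import org.apache.spark.sql.types.{ArrayType, IntegerType}

val arrayOfInt = ArrayType(IntegerType)
// Expected to succeed, assuming IntegerType is among the types checkIfTypeSupported accepts;
// an array whose element type is not accepted would still throw UnsupportedOperationException.
ConverterUtils.checkIfTypeSupported(arrayOfInt.elementType)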
@@ -56,6 +56,7 @@ import java.io.{InputStream, OutputStream}
import java.util
import java.util.concurrent.TimeUnit.SECONDS

import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.TimeUnit
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
@@ -522,6 +523,7 @@ object ConverterUtils extends Logging {
for ( structField <- d.fields ) {
checkIfTypeSupported(structField.dataType)
}
case d: MapType =>
case d: BooleanType =>
case d: ByteType =>
case d: ShortType =>
@@ -537,26 +539,35 @@
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def createArrowField(name: String, dt: DataType): Field = dt match {
def createArrowField(name: String, dt: DataType, nullable: Boolean = true): Field = dt match {
case at: ArrayType =>
new Field(
name,
FieldType.nullable(ArrowType.List.INSTANCE),
new FieldType(nullable, ArrowType.List.INSTANCE, null),
Lists.newArrayList(createArrowField(s"${name}_${dt}", at.elementType)))
case mt: MapType =>
throw new UnsupportedOperationException(s"${dt} is not supported yet")
case st: StructType =>
val fieldlist = new util.ArrayList[Field]
var structField = null
for ( structField <- st.fields ) {
fieldlist.add(createArrowField(structField.name, structField.dataType))
fieldlist.add(createArrowField(structField.name, structField.dataType, structField.nullable))
}
new Field(
name,
FieldType.nullable(ArrowType.Struct.INSTANCE),
new FieldType(nullable, ArrowType.Struct.INSTANCE, null),
fieldlist)
case mt: MapType =>
// Note: the map's entries struct cannot be null, and the struct's key field cannot be null
new Field(
name,
new FieldType(nullable, new ArrowType.Map(false), null),
Lists.newArrayList(createArrowField(MapVector.DATA_VECTOR_NAME,
new StructType()
.add(MapVector.KEY_NAME, mt.keyType, false)
.add(MapVector.VALUE_NAME, mt.valueType, mt.valueContainsNull),
nullable = false
)))
case _ =>
Field.nullable(name, CodeGeneration.getResultType(dt))
new Field (name, new FieldType(nullable, CodeGeneration.getResultType(dt), null), null)
}

def createArrowField(attr: Attribute): Field =
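
For reference, a sketch (not part of the commit) of the field tree the new MapType case is expected to produce for MapType(IntegerType, StringType, valueContainsNull = true). The child names come from Arrow's MapVector constants ("entries", "key", "value"); the leaf Arrow types depend on CodeGeneration.getResultType.

import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

val field = ConverterUtils.createArrowField("m", MapType(IntegerType, StringType))
assert(field.getType.isInstanceOf[ArrowType.Map])  // nullable map column
val entries = field.getChildren.get(0)             // "entries": non-nullable struct
val key = entries.getChildren.get(0)               // "key": non-nullable
val value = entries.getChildren.get(1)             // "value": nullable, per valueContainsNull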
@@ -86,13 +86,30 @@ case class ColumnarShuffleExchangeExec(
// check input datatype
for (attr <- child.output) {
try {
ConverterUtils.checkIfNestTypeSupported(attr.dataType)
ConverterUtils.createArrowField(attr)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarShuffledExchangeExec.")
}
}

// Check partitioning keys
outputPartitioning match {
case HashPartitioning(exprs, n) =>
exprs.zipWithIndex.foreach {
case (expr, i) =>
val attr = ConverterUtils.getAttrFromExpr(expr)
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarShuffledExchangeExec Partitioning.")
}
}
case _ =>
}
}

val serializer: Serializer = new ArrowColumnarBatchSerializer(
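
A simplified sketch (not the plugin's actual method) of the two checks added above: every output attribute must be convertible to an Arrow field, where createArrowField now accepts nested Struct, Array, and Map types, while hash-partitioning keys go through the stricter checkIfTypeSupported. A complex-typed key therefore raises UnsupportedOperationException and the plan falls back to the row-based ShuffleExchangeExec, as exercised by the last test case in the suite below.

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

// Hypothetical helper: true when the columnar shuffle can handle this plan.
def canUseColumnarShuffle(output: Seq[Attribute], hashKeys: Seq[Expression]): Boolean =
  try {
    output.foreach(attr => ConverterUtils.createArrowField(attr))
    hashKeys.foreach { expr =>
      val attr = ConverterUtils.getAttrFromExpr(expr)
      ConverterUtils.checkIfTypeSupported(attr.dataType)
    }
    true
  } catch {
    case _: UnsupportedOperationException => false
  }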
@@ -15,14 +15,16 @@
* limitations under the License.
*/

package com.intel.oap.misc
package org.apache.spark.shuffle

import java.nio.file.Files

import com.intel.oap.tpc.util.TPCRunner
import org.apache.log4j.{Level, LogManager}
import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.test.SharedSparkSession

@@ -58,12 +60,9 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.oap.sql.columnar.sortmergejoin.lazyread", "true")
.set("spark.oap.sql.columnar.autorelease", "false")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.shuffle.partitions", "50")
.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "5")
.set("spark.oap.sql.columnar.shuffledhashjoin.buildsizelimit", "200m")
.set("spark.oap.sql.columnar.rowtocolumnar", "false")
.set("spark.oap.sql.columnar.columnartorow", "false")
return conf
}

@@ -76,7 +75,17 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
lPath = lfile.getAbsolutePath
spark.range(2).select(col("id"), expr("1").as("kind"),
expr("array(1, 2)").as("arr_field"),
expr("struct(1, 2)").as("struct_field"))
expr("array(array(1, 2), array(3, 4))").as("arr_arr_field"),
expr("array(struct(1, 2), struct(1, 2))").as("arr_struct_field"),
expr("array(map(1, 2), map(3,4))").as("arr_map_field"),
expr("struct(1, 2)").as("struct_field"),
expr("struct(1, struct(1, 2))").as("struct_struct_field"),
expr("struct(1, array(1, 2))").as("struct_array_field"),
expr("map(1, 2)").as("map_field"),
expr("map(1, map(3,4))").as("map_map_field"),
expr("map(1, array(1, 2))").as("map_arr_field"),
expr("map(struct(1, 2), 2)").as("map_struct_field"))
.coalesce(1)
.write
.format("parquet")
.mode("overwrite")
@@ -88,6 +97,7 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
spark.range(2).select(col("id"), expr("id % 2").as("kind"),
expr("array(1, 2)").as("arr_field"),
expr("struct(1, 2)").as("struct_field"))
.coalesce(1)
.write
.format("parquet")
.mode("overwrite")
@@ -101,16 +111,97 @@ class ComplexTypeSuite extends QueryTest with SharedSparkSession {
val df = spark.sql("SELECT ltab.arr_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.count() == 2)
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count == 2)
}

test("Test Nest Array in Shuffle split") {
val df = spark.sql("SELECT ltab.arr_arr_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count == 2)
}

test("Test Array_Struct in Shuffle split") {
val df = spark.sql("SELECT ltab.arr_struct_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count == 2)
}

test("Test Array_Map in Shuffle split") {
val df = spark.sql("SELECT ltab.arr_map_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count == 2)
}

test("Test Struct in Shuffle stage") {
val df = spark.sql("SELECT ltab.struct_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Nest Struct in Shuffle stage") {
val df = spark.sql("SELECT ltab.struct_struct_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Struct_Array in Shuffle stage") {
val df = spark.sql("SELECT ltab.struct_array_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Map in Shuffle stage") {
val df = spark.sql("SELECT ltab.map_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Nest Map in Shuffle stage") {
val df = spark.sql("SELECT ltab.map_map_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Map_Array in Shuffle stage") {
val df = spark.sql("SELECT ltab.map_arr_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Map_Struct in Shuffle stage") {
val df = spark.sql("SELECT ltab.map_struct_field FROM ltab, rtab WHERE ltab.kind = rtab.kind")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarShuffleExchangeExec]).isDefined)
assert(df.count() == 2)
}

test("Test Fall back with complex type in Partitioning keys") {
val df = spark.sql("SELECT ltab.arr_field FROM ltab, rtab WHERE ltab.arr_field = rtab.arr_field")
df.explain(true)
df.show()
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined)
}

override def afterAll(): Unit = {
super.afterAll()
}
