Skip to content

Commit

Permalink
[SPARK-31854][SQL] Invoke in MapElementsExec should not propagate null
Browse files Browse the repository at this point in the history
This PR intends to fix a bug of `Dataset.map` below when the whole-stage codegen enabled;
```
scala> val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS()

scala> sql("SET spark.sql.codegen.wholeStage=true")

scala> ds.map(v=>(v,v)).explain
== Physical Plan ==
*(1) SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true])._1.intValue AS _1#69, assertnotnull(input[0, scala.Tuple2, true])._2.intValue AS _2#70]
+- *(1) MapElements <function1>, obj#68: scala.Tuple2
   +- *(1) DeserializeToObject staticinvoke(class java.lang.Integer, ObjectType(class java.lang.Integer), valueOf, value#1, true, false), obj#67: java.lang.Integer
      +- LocalTableScan [value#1]

// `AssertNotNull` in `SerializeFromObject` will fail;
scala> ds.map(v => (v, v)).show()
java.lang.NullPointerException: Null value appeared in non-nullable fails:
top level Product input object
If the schema is inferred from a Scala tuple/case class, or a Java bean, please try to use scala.Option[_] or other nullable types (e.g. java.lang.Integer instead of int/scala.Int).

// When the whole-stage codegen disabled, the query works well;
scala> sql("SET spark.sql.codegen.wholeStage=false")
scala> ds.map(v=>(v,v)).show()
+----+----+
|  _1|  _2|
+----+----+
|   1|   1|
|null|null|
+----+----+
```
A root cause is that `Invoke` used in `MapElementsExec` propagates input null, and then [AssertNotNull](https://github.com/apache/spark/blob/1b780f364bfbb46944fe805a024bb6c32f5d2dde/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala#L253-L255) in `SerializeFromObject` fails because a top-level row becomes null. So, `MapElementsExec` should not return `null` but `(null, null)`.

NOTE: the generated code of the query above in the current master;
```
/* 033 */   private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException {
/* 034 */     boolean mapelements_isNull_1 = true;
/* 035 */     scala.Tuple2 mapelements_value_1 = null;
/* 036 */     if (!false) {
/* 037 */       mapelements_resultIsNull_0 = false;
/* 038 */
/* 039 */       if (!mapelements_resultIsNull_0) {
/* 040 */         mapelements_resultIsNull_0 = mapelements_exprIsNull_0_0;
/* 041 */         mapelements_mutableStateArray_0[0] = mapelements_expr_0_0;
/* 042 */       }
/* 043 */
/* 044 */       mapelements_isNull_1 = mapelements_resultIsNull_0;
/* 045 */       if (!mapelements_isNull_1) {
/* 046 */         Object mapelements_funcResult_0 = null;
/* 047 */         mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]);
/* 048 */
/* 049 */         if (mapelements_funcResult_0 != null) {
/* 050 */           mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0;
/* 051 */         } else {
/* 052 */           mapelements_isNull_1 = true;
/* 053 */         }
/* 054 */
/* 055 */       }
/* 056 */     }
/* 057 */
/* 058 */     serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1);
/* 059 */
/* 060 */   }
```

The generated code w/ this fix;
```
/* 032 */   private void mapelements_doConsume_0(java.lang.Integer mapelements_expr_0_0, boolean mapelements_exprIsNull_0_0) throws java.io.IOException {
/* 033 */     boolean mapelements_isNull_1 = true;
/* 034 */     scala.Tuple2 mapelements_value_1 = null;
/* 035 */     if (!false) {
/* 036 */       mapelements_mutableStateArray_0[0] = mapelements_expr_0_0;
/* 037 */
/* 038 */       mapelements_isNull_1 = false;
/* 039 */       if (!mapelements_isNull_1) {
/* 040 */         Object mapelements_funcResult_0 = null;
/* 041 */         mapelements_funcResult_0 = ((scala.Function1) references[1] /* literal */).apply(mapelements_mutableStateArray_0[0]);
/* 042 */
/* 043 */         if (mapelements_funcResult_0 != null) {
/* 044 */           mapelements_value_1 = (scala.Tuple2) mapelements_funcResult_0;
/* 045 */           mapelements_isNull_1 = false;
/* 046 */         } else {
/* 047 */           mapelements_isNull_1 = true;
/* 048 */         }
/* 049 */
/* 050 */       }
/* 051 */     }
/* 052 */
/* 053 */     serializefromobject_doConsume_0(mapelements_value_1, mapelements_isNull_1);
/* 054 */
/* 055 */   }
```

Bugfix.

No.

Added tests.

Closes #28681 from maropu/SPARK-31854.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit b806fc4)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
maropu authored and cloud-fan committed Jun 1, 2020
1 parent a5a8ec2 commit efa0269
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@ case class MapElementsExec(
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
val (funcClass, funcName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output)
val callFunc = Invoke(funcObj, funcName, outputObjectType, child.output, propagateNull = false)

val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)

Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1901,6 +1901,16 @@ class DatasetSuite extends QueryTest

assert(active eq SparkSession.getActiveSession.get)
}

test("SPARK-31854: Invoke in MapElementsExec should not propagate null") {
Seq("true", "false").foreach { wholeStage =>
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage) {
val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS()
val expectedAnswer = Seq[(Integer, Integer)]((1, 1), (null, null))
checkDataset(ds.map(v => (v, v)), expectedAnswer: _*)
}
}
}
}

object AssertExecutionId {
Expand Down

0 comments on commit efa0269

Please sign in to comment.