Skip to content

Commit

Permalink
[SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

A current `CodegenContext` class has immutable value or method without mutable state, too.
This refactoring moves them to `CodeGenerator` object class which can be accessed from anywhere without an instantiated `CodegenContext` in the program.

## How was this patch tested?

Existing tests

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #20700 from kiszk/SPARK-23546.
  • Loading branch information
kiszk authored and hvanhovell committed Mar 5, 2018
1 parent 269cd53 commit 2ce37b5
Show file tree
Hide file tree
Showing 45 changed files with 535 additions and 497 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
ev.copy(code = oev.code)
} else {
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
val javaType = CodeGenerator.javaType(dataType)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
s"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
s"""
boolean $resultIsNull = $inputIsNull;
${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
}
Expand All @@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val funcName = ctx.freshName("elementToString")
val elementToStringFunc = ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(et)} element) {
|private UTF8String $funcName(${CodeGenerator.javaType(et)} element) {
| UTF8String elementStr = null;
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
| return elementStr;
Expand All @@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$buffer.append("[");
|if ($array.numElements() > 0) {
| if (!$array.isNullAt(0)) {
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
| $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")}));
| }
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
| $buffer.append(",");
| if (!$array.isNullAt($loopIndex)) {
| $buffer.append(" ");
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
| $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)}));
| }
| }
|}
Expand All @@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val dataToStringCode = castToStringCode(dataType, ctx)
ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(dataType)} data) {
|private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) {
| UTF8String dataStr = null;
| ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
| return dataStr;
Expand All @@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val keyToStringFunc = dataToStringFunc("keyToString", kt)
val valueToStringFunc = dataToStringFunc("valueToString", vt)
val loopIndex = ctx.freshName("loopIndex")
val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0")
val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0")
val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex)
val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex)
s"""
|$buffer.append("[");
|if ($map.numElements() > 0) {
| $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
| $buffer.append($keyToStringFunc($getMapFirstKey));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt(0)) {
| $buffer.append(" ");
| $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
| $buffer.append($valueToStringFunc($getMapFirstValue));
| }
| for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
| $buffer.append(", ");
| $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
| $buffer.append($keyToStringFunc($getMapKeyArray));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt($loopIndex)) {
| $buffer.append(" ");
| $buffer.append($valueToStringFunc(
| ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
| $buffer.append($valueToStringFunc($getMapValueArray));
| }
| }
|}
Expand All @@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
| ${if (i != 0) s"""$buffer.append(" ");""" else ""}
|
| // Append $i field into the string buffer
| ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
| ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
Expand Down Expand Up @@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
$values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(fromType)} $fromElementPrim =
${ctx.getValue(c, fromType, j)};
${CodeGenerator.javaType(fromType)} $fromElementPrim =
${CodeGenerator.getValue(c, fromType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
if ($toElementNull) {
Expand Down Expand Up @@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
val toFieldNull = ctx.freshName("tfn")
val fromType = ctx.javaType(from.fields(i).dataType)
val fromType = CodeGenerator.javaType(from.fields(i).dataType)
s"""
boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
$tmpResult.setNullAt($i);
} else {
$fromType $fromFieldPrim =
${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
$tmpResult.setNullAt($i);
} else {
${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
}
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ abstract class Expression extends TreeNode[Expression] {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
} else {
""
}

val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")

val funcName = ctx.freshName(nodeName)
Expand Down Expand Up @@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression {
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down Expand Up @@ -510,15 +510,15 @@ abstract class BinaryExpression extends Expression {

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down Expand Up @@ -654,15 +654,15 @@ abstract class TernaryExpression extends Expression {

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${leftGen.code}
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}

/**
Expand Down Expand Up @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = "false")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1018,11 +1018,12 @@ case class ScalaUDF(
val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
val resultConverter = s"$convertersTerm[${children.length}]"
val boxedType = CodeGenerator.boxedType(dataType)
val callFunc =
s"""
|${ctx.boxedType(dataType)} $resultTerm = null;
|$boxedType $resultTerm = null;
|try {
| $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
| $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
|} catch (Exception e) {
| throw new org.apache.spark.SparkException($errorMsgTerm, e);
|}
Expand All @@ -1035,7 +1036,7 @@ case class ScalaUDF(
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
Expand All @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = "false")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -165,7 +165,7 @@ case class PreciseTimestampConversion(
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
|${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input
Expand Down
Loading

0 comments on commit 2ce37b5

Please sign in to comment.