respond to code review comments
dtenedor committed Sep 13, 2024
1 parent 3504124 commit 0683318
Showing 5 changed files with 40 additions and 59 deletions.
File 1 of 5
@@ -386,7 +386,7 @@ object FunctionRegistry {
expression[Randn]("randn"),
expression[RandStr]("randstr"),
expression[Stack]("stack"),
-expressionBuilder("uniform", UniformExpressionBuilder),
+expression[Uniform]("uniform"),
expression[ZeroIfNull]("zeroifnull"),
CaseWhen.registryEntry,

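Switching the registration from expressionBuilder to expression[Uniform] moves arity dispatch onto Uniform's constructors, so both the two- and three-argument SQL forms still resolve. A minimal sketch of the resulting behavior, assuming a running SparkSession named spark (illustrative, not part of the commit):

    // Sketch: both arities resolve through Uniform's constructors once the
    // function is registered reflectively. `spark` is an assumed SparkSession.
    spark.sql("SELECT uniform(10, 20) AS r").show()     // two-arg form: constructor draws a random seed
    spark.sql("SELECT uniform(10, 20, 0) AS r").show()  // three-arg form: explicit seed, repeatable
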
File 2 of 5
@@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import scala.util.Random

import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedSeed}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
@@ -39,8 +39,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic
-  with ExpressionWithRandomSeed {
+trait RDG extends Expression with ExpressionWithRandomSeed {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -49,12 +48,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm

override def stateful: Boolean = true

-  override protected def initializeInternal(partitionIndex: Int): Unit = {
-    rng = new XORShiftRandom(seed + partitionIndex)
-  }
-
-  override def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
@@ -63,6 +56,15 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
override def nullable: Boolean = false

override def dataType: DataType = DoubleType
+}
+
+abstract class NondeterministicUnaryRDG
+  extends RDG with UnaryLike[Expression] with Nondeterministic with ExpectsInputTypes {
+  override def seedExpression: Expression = child
+
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    rng = new XORShiftRandom(seed + partitionIndex)
+  }

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}
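After this hunk, RDG carries only the seed plumbing shared by every random generator, while NondeterministicUnaryRDG restores the unary, per-partition-initialized behavior used by Rand and Randn. As a hedged sketch of how a new generator would plug into the refactored hierarchy (CoinFlip is a hypothetical name, not part of the commit; assumes spark-catalyst on the classpath and the rng field declared earlier in RDG):

    package org.apache.spark.sql.catalyst.expressions

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
    import org.apache.spark.sql.types.LongType

    // Hypothetical expression, not part of this commit: flips a seeded coin per row.
    // seedExpression, initializeInternal, inputTypes, dataType, and stateful are
    // all inherited from NondeterministicUnaryRDG / RDG.
    case class CoinFlip(child: Expression)
        extends NondeterministicUnaryRDG with CodegenFallback {
      def this() = this(UnresolvedSeed)
      override def withNewSeed(seed: Long): CoinFlip = CoinFlip(Literal(seed, LongType))
      override protected def evalInternal(input: InternalRow): Double =
        if (rng.nextBoolean()) 1.0 else 0.0
      override protected def withNewChildInternal(newChild: Expression): CoinFlip =
        copy(child = newChild)
    }
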
@@ -105,7 +107,7 @@ private[catalyst] object ExpressionWithRandomSeed {
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
+case class Rand(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG {

def this() = this(UnresolvedSeed, true)

@@ -156,7 +158,7 @@ object Rand {
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
+case class Randn(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG {

def this() = this(UnresolvedSeed, true)

@@ -199,29 +201,16 @@ object Randn {
""",
examples = """
Examples:
-      > SELECT _FUNC_(10, 20) > 0 AS result;
+      > SELECT _FUNC_(10, 20, 0) > 0 AS result;
true
""",
since = "4.0.0",
group = "math_funcs")
-object UniformExpressionBuilder extends ExpressionBuilder {
-  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
-    val numArgs = expressions.length
-    expressions match {
-      case Seq(min, max) =>
-        Uniform(min, max)
-      case Seq(min, max, seed) =>
-        Uniform(min, max, seed)
-      case _ =>
-        throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), numArgs)
-    }
-  }
-}

-case class Uniform(min: Expression, max: Expression)
-  extends RuntimeReplaceable with BinaryLike[Expression] with ExpressionWithRandomSeed {
-  private var seed: Expression = Literal(Uniform.random.nextLong(), LongType)
+case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
+  extends RuntimeReplaceable with TernaryLike[Expression] with RDG {
+  def this(min: Expression, max: Expression) =
+    this(min, max, Literal(Uniform.random.nextLong(), LongType))
final override lazy val deterministic: Boolean = false
override val nodePatterns: Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
@@ -255,7 +244,7 @@ case class Uniform(min: Expression, max: Expression)
def requiredType = "integer or floating-point"
Seq((min, "min", 0),
(max, "max", 1),
-  (seed, "seed", 2)).foreach {
+  (seedExpression, "seed", 2)).foreach {
case (expr: Expression, name: String, index: Int) =>
if (!expr.foldable && result == TypeCheckResult.TypeCheckSuccess) {
result = DataTypeMismatch(
@@ -277,18 +266,16 @@ case class Uniform(min: Expression, max: Expression)
result
}

-  override def left: Expression = min
-  override def right: Expression = max
+  override def first: Expression = min
+  override def second: Expression = max
+  override def third: Expression = seedExpression

-  override def seedExpression: Expression = seed
-  override def withNewSeed(newSeed: Long): Expression = {
-    val result = Uniform(min, max)
-    result.seed = Literal(newSeed, LongType)
-    result
-  }
+  override def withNewSeed(newSeed: Long): Expression =
+    Uniform(min, max, Literal(newSeed, LongType))

-  override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
-    Uniform(newFirst, newSecond)
+  override def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    Uniform(newFirst, newSecond, newThird)

override def replacement: Expression = {
def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to)
@@ -305,12 +292,6 @@ case class Uniform(min: Expression, max: Expression)

object Uniform {
lazy val random = new Random()

-  def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform = {
-    val result = Uniform(min, max)
-    result.seed = seedExpression
-    result
-  }
}
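With the seed held as a real third child instead of a private var, re-seeding becomes a pure construction, which is why the custom apply overload above is no longer needed. A small sketch, assuming spark-catalyst on the classpath:

    import org.apache.spark.sql.catalyst.expressions.{Literal, Uniform}

    // The two-arg constructor draws a seed from Uniform.random; withNewSeed now
    // returns a fresh immutable copy instead of mutating expression state.
    val u  = new Uniform(Literal(10), Literal(20))
    val u0 = u.withNewSeed(0L)  // == Uniform(Literal(10), Literal(20), Literal(0L, LongType))
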

@ExpressionDescription(
File 3 of 5
@@ -368,7 +368,7 @@
| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct<negative(1):int> |
| org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> |
| org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct<decode(unhex(537061726B2053514C), UTF-8):string> |
-| org.apache.spark.sql.catalyst.expressions.UniformExpressionBuilder | uniform | SELECT uniform(10, 20) > 0 AS result | struct<result:boolean> |
+| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20, 0) > 0 AS result | struct<result:boolean> |
| org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct<unix_date(1970-01-02):int> |
| org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct<unix_micros(1970-01-01 00:00:01Z):bigint> |
| org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct<unix_millis(1970-01-01 00:00:01Z):bigint> |
File 4 of 5
@@ -173,7 +173,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"VOID\"",
"paramIndex" : "first",
"requiredType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(NULL, 1)\""
+"sqlExpr" : "\"uniform(NULL, 1, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -197,7 +197,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"VOID\"",
"paramIndex" : "second",
"requiredType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(0, NULL)\""
+"sqlExpr" : "\"uniform(0, NULL, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -221,7 +221,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"VOID\"",
"paramIndex" : "third",
"requiredType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(0, 1)\""
+"sqlExpr" : "\"uniform(0, 1, NULL)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -244,7 +244,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputExpr" : "\"col\"",
"inputName" : "seed",
"inputType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(10, 20)\""
+"sqlExpr" : "\"uniform(10, 20, col)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -267,7 +267,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputExpr" : "\"col\"",
"inputName" : "min",
"inputType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(col, 10)\""
+"sqlExpr" : "\"uniform(col, 10, 0)\""
},
"queryContext" : [ {
"objectType" : "",
File 5 of 5: sql/core/src/test/resources/sql-tests/results/random.sql.out (5 additions, 5 deletions)
@@ -219,7 +219,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"VOID\"",
"paramIndex" : "first",
"requiredType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(NULL, 1)\""
+"sqlExpr" : "\"uniform(NULL, 1, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -245,7 +245,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"VOID\"",
"paramIndex" : "second",
"requiredType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(0, NULL)\""
+"sqlExpr" : "\"uniform(0, NULL, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -271,7 +271,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputType" : "\"VOID\"",
"paramIndex" : "third",
"requiredType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(0, 1)\""
+"sqlExpr" : "\"uniform(0, 1, NULL)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -296,7 +296,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputExpr" : "\"col\"",
"inputName" : "seed",
"inputType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(10, 20)\""
+"sqlExpr" : "\"uniform(10, 20, col)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -321,7 +321,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputExpr" : "\"col\"",
"inputName" : "min",
"inputType" : "integer or floating-point",
-"sqlExpr" : "\"uniform(col, 10)\""
+"sqlExpr" : "\"uniform(col, 10, 0)\""
},
"queryContext" : [ {
"objectType" : "",
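The updated sqlExpr strings in these golden files show the seed printed as an ordinary third argument. For reference, the non-foldable-seed failure captured above could be reproduced along these lines (assumes a SparkSession named spark; the error text is taken from the diff, not re-run):

    // Fails analysis: the seed argument must be a foldable integer or floating-point
    // constant, and `col` is a column reference, so the type check reports
    //   "sqlExpr" : "\"uniform(10, 20, col)\""
    spark.sql("SELECT uniform(10, 20, col) FROM VALUES (0) AS tab(col)").collect()
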
