Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48719][SQL] Fix the calculation bug of RegrSlope & RegrIntercept when the first parameter is null #47105

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg

private val covarPop = new CovPopulation(right, left)

private val varPop = new VariancePop(right)
private val varPop = new VariancePop(If(And(IsNotNull(left), IsNotNull(right)),
right, Literal.create(null, DoubleType)))

override def nullable: Boolean = true

Expand Down Expand Up @@ -311,7 +312,8 @@ case class RegrIntercept(left: Expression, right: Expression) extends Declarativ

private val covarPop = new CovPopulation(right, left)

private val varPop = new VariancePop(right)
private val varPop = new VariancePop(If(And(IsNotNull(left), IsNotNull(right)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of adding null check to the underlying VariancePop, shall we add it to RegrIntercept? e.g. we can add null check to updateExpressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried changing it locally and related test cases passed. It seems that the logic is more complicated. Is it worth updating like this? WDYT @cloud-fan

  override lazy val updateExpressions: Seq[Expression] = {
    val isNull = left.isNull || right.isNull
    val updateResult = covarPop.updateExpressions ++ varPop.updateExpressions
    aggBufferAttributes.zip(updateResult).map { case (oldValue, newValue) =>
      If(isNull, oldValue, newValue)
    }
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's more efficient, as we do null check earlier, before we execute VariancePop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the value is calculated regardless of whether it is null with code val updateResult = covarPop.updateExpressions ++ varPop.updateExpressions.

It's similar to the updateExpressionsDef func in Covariance. The value is calculated first, and finally different values ​​are returned depending on whether it is null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't found a better way to implement this logic. If can use if else directly, it may not be calculated in advance, but it cannot be expressed directly here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated it, would you like to take another look? Thank you. @cloud-fan

right, Literal.create(null, DoubleType)))

override def nullable: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x)
-- !query analysis
CreateViewCommand `testRegression`, SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x), false, true, LocalTempView, UNSUPPORTED, true
+- Project [k#x, y#x, x#x]
+- SubqueryAlias testRegression
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-- Test data.
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x);

-- SPARK-37613: Support ANSI Aggregate Function: regr_count
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A tuple is added with the value of y is null, it should be filtered out during calculation, so the output related to RegrSlope & RegrIntercept in the output remains unchanged.

AS testRegression(k, y, x)
-- !query schema
struct<>
Expand Down Expand Up @@ -31,7 +31,7 @@ SELECT k, count(*), regr_count(y, x) FROM testRegression GROUP BY k
struct<k:int,count(1):bigint,regr_count(y, x):bigint>
-- !query output
1 1 0
2 4 3
2 5 3


-- !query
Expand All @@ -40,7 +40,7 @@ SELECT k, count(*) FILTER (WHERE x IS NOT NULL), regr_count(y, x) FROM testRegre
struct<k:int,count(1) FILTER (WHERE (x IS NOT NULL)):bigint,regr_count(y, x):bigint>
-- !query output
1 0 0
2 3 3
2 4 3


-- !query
Expand Down Expand Up @@ -99,7 +99,7 @@ SELECT k, avg(x), avg(y), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression G
struct<k:int,avg(x):double,avg(y):double,regr_avgx(y, x):double,regr_avgy(y, x):double>
-- !query output
1 NULL 10.0 NULL NULL
2 22.666666666666668 21.25 22.666666666666668 20.0
2 27.0 21.25 22.666666666666668 20.0


-- !query
Expand All @@ -116,15 +116,15 @@ SELECT regr_sxx(y, x) FROM testRegression
-- !query schema
struct<regr_sxx(y, x):double>
-- !query output
288.66666666666663
288.6666666666667


-- !query
SELECT regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL
-- !query schema
struct<regr_sxx(y, x):double>
-- !query output
288.66666666666663
288.6666666666667


-- !query
Expand All @@ -133,15 +133,15 @@ SELECT k, regr_sxx(y, x) FROM testRegression GROUP BY k
struct<k:int,regr_sxx(y, x):double>
-- !query output
1 NULL
2 288.66666666666663
2 288.6666666666667


-- !query
SELECT k, regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k
-- !query schema
struct<k:int,regr_sxx(y, x):double>
-- !query output
2 288.66666666666663
2 288.6666666666667


-- !query
Expand Down Expand Up @@ -215,15 +215,15 @@ SELECT regr_slope(y, x) FROM testRegression
-- !query schema
struct<regr_slope(y, x):double>
-- !query output
0.8314087759815244
0.8314087759815242


-- !query
SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL
-- !query schema
struct<regr_slope(y, x):double>
-- !query output
0.8314087759815244
0.8314087759815242


-- !query
Expand All @@ -232,15 +232,15 @@ SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k
struct<k:int,regr_slope(y, x):double>
-- !query output
1 NULL
2 0.8314087759815244
2 0.8314087759815242


-- !query
SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k
-- !query schema
struct<k:int,regr_slope(y, x):double>
-- !query output
2 0.8314087759815244
2 0.8314087759815242


-- !query
Expand Down