[Only Test][Don't Review] Make FilterExec to support subexpressionElimination in the codegen scenario #49573

Draft · wants to merge 3 commits into base: master
Conversation

@panbingkun (Contributor) commented on Jan 20, 2025

What changes were proposed in this pull request?

This PR aims to make FilterExec support subexpression elimination in the codegen scenario.

Why are the changes needed?

Improve performance: common subexpressions in a filter predicate are evaluated once per row in the generated code instead of once per occurrence.
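
As a hypothetical illustration (the names and data below are not from this PR): a predicate that repeats a common subexpression, which the generated filter code would otherwise evaluate once per occurrence per row:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // trim($"s") is a common subexpression of both conjuncts; with
    // subexpression elimination the generated code computes it once
    // per row and reuses the result for both comparisons.
    Seq("  spark  ", "  a  ").toDF("s")
      .filter(length(trim($"s")) > 1 && length(trim($"s")) < 10)
      .show()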

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Pass GA (GitHub Actions).

Was this patch authored or co-authored using generative AI tooling?

No.

@github-actions github-actions bot added the SQL label Jan 20, 2025
@panbingkun panbingkun changed the title [Only Test][Don't Review] FilterExec Codegen SubexpressionElimination [Only Test][Don't Review] Make FilterExec to support subexpressionElimination in the codegen scenario Jan 22, 2025
@panbingkun (Contributor, Author) commented on Jan 23, 2025

  • A bad case

    test("sliding window grouping") {
      val df1 = Seq(
        ("2016-03-27 19:39:34", 1, "a"),
        ("2016-03-27 19:39:56", 2, "a"),
        ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
      val df2 = Seq(
        (LocalDateTime.parse("2016-03-27T19:39:34"), 1, "a"),
        (LocalDateTime.parse("2016-03-27T19:39:56"), 2, "a"),
        (LocalDateTime.parse("2016-03-27T19:39:27"), 4, "b")).toDF("time", "value", "id")
      Seq(df1, df2).foreach { df =>
        checkAnswer(
          df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second"))
            .agg(count("*").as("counts"))
            .orderBy($"window.start".asc)
            .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
          // 2016-03-27 19:39:27 UTC -> 4 bins
          // 2016-03-27 19:39:34 UTC -> 3 bins
          // 2016-03-27 19:39:56 UTC -> 3 bins
          Seq(
            Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1),
            Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1),
            Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1),
            Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2),
            Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
            Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1),
            Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1),
            Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1),
            Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1))
        )
      }
    }

  • codegen code (generated when the expression-split path is triggered)

/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean hashAgg_initAgg_0;
/* 010 */   private org.apache.spark.unsafe.KVIterator hashAgg_mapIter_0;
/* 011 */   private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap hashAgg_hashMap_0;
/* 012 */   private org.apache.spark.sql.execution.UnsafeKVExternalSorter hashAgg_sorter_0;
/* 013 */   private scala.collection.Iterator localtablescan_input_0;
/* 014 */   private boolean expand_resultIsNull_0;
/* 015 */   private long filter_subExprValue_0;
/* 016 */   private boolean filter_subExprIsNull_0;
/* 017 */   private long filter_subExprValue_1;
/* 018 */   private boolean filter_subExprIsNull_1;
/* 019 */   private long filter_subExprValue_2;
/* 020 */   private boolean filter_subExprIsNull_2;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[15];
/* 022 */   private InternalRow[] expand_mutableStateArray_0 = new InternalRow[1];
/* 023 */
/* 024 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 025 */     this.references = references;
/* 026 */   }
/* 027 */
/* 028 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 029 */     partitionIndex = index;
/* 030 */     this.inputs = inputs;
/* 031 */     wholestagecodegen_init_0_0();
/* 032 */     wholestagecodegen_init_0_1();
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void wholestagecodegen_init_0_1() {
/* 037 */     filter_mutableStateArray_0[8] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_mutableStateArray_0[7], 2);
/* 038 */     filter_mutableStateArray_0[9] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 039 */     filter_mutableStateArray_0[10] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_mutableStateArray_0[9], 2);
/* 040 */     filter_mutableStateArray_0[11] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 041 */     filter_mutableStateArray_0[12] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_mutableStateArray_0[11], 2);
/* 042 */     filter_mutableStateArray_0[13] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
/* 043 */     filter_mutableStateArray_0[14] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_mutableStateArray_0[13], 2);
/* 044 */
/* 045 */   }
/* 046 */
/* 047 */   private void wholestagecodegen_init_0_0() {
/* 048 */     localtablescan_input_0 = inputs[0];
/* 049 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 64);
/* 050 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 051 */     filter_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 052 */     expand_resultIsNull_0 = true;
/* 053 */     expand_mutableStateArray_0[0] = null;
/* 054 */     filter_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 64);
/* 055 */     filter_mutableStateArray_0[4] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_mutableStateArray_0[3], 2);
/* 056 */     filter_mutableStateArray_0[5] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 64);
/* 057 */     filter_mutableStateArray_0[6] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_mutableStateArray_0[5], 2);
/* 058 */     filter_mutableStateArray_0[7] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 059 */
/* 060 */   }
/* 061 */
/* 062 */   private void filter_subExpr_0(boolean expand_resultIsNull_0, org.apache.spark.sql.catalyst.InternalRow expand_mutableStateArray_0[0]) {
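/* ^^^ */   // NOTE: "expand_mutableStateArray_0[0]" above is not a valid Java
/* ^^^ */   // parameter declaration (an array element cannot be a parameter),
/* ^^^ */   // which is why this generated class fails to compile.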
/* 063 */     // 1...
/* 064 */     boolean filter_isNull_11 = expand_resultIsNull_0;
/* 065 */     long filter_value_12 = -1L;
/* 066 */
/* 067 */     if (!expand_resultIsNull_0) {
/* 068 */       if (expand_mutableStateArray_0[0].isNullAt(0)) {
/* 069 */         filter_isNull_11 = true;
/* 070 */       } else {
/* 071 */         filter_value_12 = expand_mutableStateArray_0[0].getLong(0);
/* 072 */       }
/* 073 */
/* 074 */     }
/* 075 */     // 2...
/* 076 */     filter_subExprIsNull_0 = filter_isNull_11;
/* 077 */     // 3...
/* 078 */     filter_subExprValue_0 = filter_value_12;
/* 079 */   }
/* 080 */
/* 081 */   protected void processNext() throws java.io.IOException {
/* 082 */     if (!hashAgg_initAgg_0) {
/* 083 */       hashAgg_initAgg_0 = true;
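/* ... */   // (remaining generated code truncated in this excerpt)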
  • CodeGenerator#getLocalInputVariableValues
    (It seems there is no good support for array-type variables: the array element expand_mutableStateArray_0[0] is collected and emitted as a parameter of the split method, as seen at line 062 above.)
  • If I increase the spark.sql.codegen.methodSplitThreshold value, this case passes (I understand that the generated code then uses the global variables directly instead of being split into multiple functions); a sketch of this workaround follows below. The split decision is:

    val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold
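
A minimal sketch of that workaround, assuming a local SparkSession (65536 is an arbitrary large value; the config default is 1024). debugCodegen() from org.apache.spark.sql.execution.debug prints the whole-stage generated Java, which is how a dump like the one above can be produced:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.execution.debug._

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // Raise the split threshold so the subexpression code stays inline and
    // uses the global variables directly (no split methods are generated).
    spark.conf.set("spark.sql.codegen.methodSplitThreshold", "65536")

    val df = Seq(
        ("2016-03-27 19:39:34", 1, "a"),
        ("2016-03-27 19:39:56", 2, "a"),
        ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
      .groupBy(window($"time", "10 seconds", "3 seconds", "0 second"))
      .agg(count("*").as("counts"))

    // Print the generated code for each whole-stage-codegen subtree.
    df.debugCodegen()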
