From 37a3820dd1f524de0c40d3e171d186f0496d63d0 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Wed, 2 Aug 2023 18:50:24 +0530 Subject: [PATCH] Estimate row count when BETWEEN value is an expression --- .../main/java/io/trino/cost/FilterStatsCalculator.java | 4 ++-- .../java/io/trino/cost/TestFilterStatsCalculator.java | 10 ++++++++++ .../sql/presto/tpcds/hive/partitioned/q85.plan.txt | 4 ++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index 076abcc86a4d..20c64af07808 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -314,7 +314,8 @@ protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void @Override protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) { - if (!(node.getValue() instanceof SymbolReference)) { + SymbolStatsEstimate valueStats = getExpressionStats(node.getValue()); + if (valueStats.isUnknown()) { return PlanNodeStatsEstimate.unknown(); } if (!getExpressionStats(node.getMin()).isSingleValue()) { @@ -324,7 +325,6 @@ protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Voi return PlanNodeStatsEstimate.unknown(); } - SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue())); Expression lowerBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()); Expression upperBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index 23736bed3339..725c85c6165a 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -608,6 +608,16 @@ public void testBetweenOperatorFilter() .highValue(100.0) .nullsFraction(0.0)); + // Expression as value. CAST from DOUBLE to DECIMAL(7,2) + // Produces row count estimate without updating symbol stats + assertExpression("CAST(x AS DECIMAL(7,2)) BETWEEN CAST(DECIMAL '-2.50' AS DECIMAL(7, 2)) AND CAST(DECIMAL '2.50' AS DECIMAL(7, 2))") + .outputRowsCount(219.726563) + .symbolStats("x", symbolStats -> + symbolStats.distinctValuesCount(xStats.getDistinctValuesCount()) + .lowValue(xStats.getLowValue()) + .highValue(xStats.getHighValue()) + .nullsFraction(xStats.getNullsFraction())); + assertExpression("'a' IN ('a', 'b')").equalTo(standardInputStatistics); assertExpression("'a' IN ('a', 'b', NULL)").equalTo(standardInputStatistics); assertExpression("'a' IN ('b', 'c')").outputRowsCount(0); diff --git a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q85.plan.txt b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q85.plan.txt index 9742ebd40226..487162785ab2 100644 --- a/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q85.plan.txt +++ b/testing/trino-tests/src/test/resources/sql/presto/tpcds/hive/partitioned/q85.plan.txt @@ -36,7 +36,7 @@ local exchange (GATHER, SINGLE, []) scan customer_demographics local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan web_page + scan reason local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan reason + scan web_page