diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java index 90754efa3f83..1ffe223e0d86 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostProvider.java @@ -88,13 +88,6 @@ private PlanNodeCostEstimate getGroupCost(GroupReference groupReference) private PlanNodeCostEstimate calculateCumulativeCost(PlanNode node) { - PlanNodeCostEstimate localCosts = costCalculator.calculateCost(node, statsProvider, session, types); - - PlanNodeCostEstimate sourcesCost = node.getSources().stream() - .map(this::getCumulativeCost) - .reduce(PlanNodeCostEstimate.zero(), PlanNodeCostEstimate::add); - - PlanNodeCostEstimate cumulativeCost = localCosts.add(sourcesCost); - return cumulativeCost; + return costCalculator.calculateCost(node, statsProvider, this, session, types); } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java index 2c71bd3a7682..92ecb49e0e7a 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java @@ -31,7 +31,7 @@ public interface CostCalculator { /** - * Calculates non-cumulative cost of a node. + * Calculates cumulative cost of a node. * * @param node The node to compute cost for. * @param stats The stats provider for node's stats and child nodes' stats, to be used if stats are needed to compute cost for the {@code node} @@ -39,6 +39,7 @@ public interface CostCalculator PlanNodeCostEstimate calculateCost( PlanNode node, StatsProvider stats, + CostProvider costs, Session session, TypeProvider types); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java index a50059b8a994..12c0249641df 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java @@ -90,10 +90,16 @@ public CostCalculatorUsingExchanges(IntSupplier numberOfNodes) } @Override - public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Session session, TypeProvider types) + public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider costs, Session session, TypeProvider types) { CostEstimator costEstimator = new CostEstimator(numberOfNodes.getAsInt(), stats, types); - return node.accept(costEstimator, null); + PlanNodeCostEstimate localCosts = node.accept(costEstimator, null); + + PlanNodeCostEstimate sourcesCost = node.getSources().stream() + .map(costs::getCumulativeCost) + .reduce(PlanNodeCostEstimate.zero(), PlanNodeCostEstimate::add); + + return localCosts.add(sourcesCost); } private static class CostEstimator diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java index 354279fe8b11..5af1a5f2d2fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java @@ -65,11 +65,11 @@ public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, IntSu } @Override - public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Session session, TypeProvider types) + public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider costs, Session session, TypeProvider types) { ExchangeCostEstimator exchangeCostEstimator = new ExchangeCostEstimator(numberOfNodes.getAsInt(), stats, types); PlanNodeCostEstimate estimatedExchangeCost = node.accept(exchangeCostEstimator, null); - return costCalculator.calculateCost(node, stats, session, types).add(estimatedExchangeCost); + return costCalculator.calculateCost(node, stats, costs, session, types).add(estimatedExchangeCost); } private static class ExchangeCostEstimator diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 230ad442bd51..79885cd06393 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -88,7 +88,6 @@ import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -442,11 +441,7 @@ public PlanNodeCostEstimate getCumulativeCost(PlanNode node) private PlanNodeCostEstimate calculateCumulativeCost(PlanNode node) { - PlanNodeCostEstimate sourcesCost = node.getSources().stream() - .map(this::getCumulativeCost) - .reduce(PlanNodeCostEstimate.zero(), PlanNodeCostEstimate::add); - - return costCalculator.calculateCost(node, statsProvider, session, types).add(sourcesCost); + return costCalculator.calculateCost(node, statsProvider, this, session, types); } } @@ -505,17 +500,15 @@ private PlanNodeCostEstimate calculateCumulativeCost( Function stats, Map types) { - PlanNodeCostEstimate localCost = costCalculator.calculateCost( + StatsProvider statsProvider = planNode -> requireNonNull(stats.apply(planNode), "no stats for node"); + CostProvider costProvider = costs::apply; + return costCalculator.calculateCost( node, - planNode -> requireNonNull(stats.apply(planNode), "no stats for node"), + statsProvider, + costProvider, session, TypeProvider.copyOf(types.entrySet().stream() .collect(ImmutableMap.toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue)))); - - PlanNodeCostEstimate sourcesCost = node.getSources().stream() - .map(source -> requireNonNull(costs.apply(source), format("no cost for source: %s", source.getId()))) - .reduce(PlanNodeCostEstimate.zero(), PlanNodeCostEstimate::add); - return sourcesCost.add(localCost); } private PlanNodeCostEstimate calculateCumulativeCost(PlanNode node, CostCalculator costCalculator, StatsCalculator statsCalculator, Map types)