diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/by/AggregationProcessor.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/by/AggregationProcessor.java index 2d4e0908cd7..41b49a74c70 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/by/AggregationProcessor.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/by/AggregationProcessor.java @@ -651,45 +651,8 @@ final void addWeightedAvgOrSumOperator( addOperator(resultOperator, r.source, r.pair.input().name(), weightName); }); } - } - - // ----------------------------------------------------------------------------------------------------------------- - // Standard Aggregations - // ----------------------------------------------------------------------------------------------------------------- - - /** - * Implementation class for conversion from a collection of {@link Aggregation aggregations} to an - * {@link AggregationContext} for standard aggregations. Accumulates state by visiting each aggregation. - */ - private final class NormalConverter extends Converter { - private final QueryCompilerRequestProcessor.BatchProcessor compilationProcessor; - - private NormalConverter( - @NotNull final Table table, - final boolean requireStateChangeRecorder, - @NotNull final String... groupByColumnNames) { - super(table, requireStateChangeRecorder, groupByColumnNames); - this.compilationProcessor = QueryCompilerRequestProcessor.batch(); - } - @Override - AggregationContext build() { - final AggregationContext resultContext = super.build(); - compilationProcessor.compile(); - return resultContext; - } - - // ------------------------------------------------------------------------------------------------------------- - // Aggregation.Visitor - // ------------------------------------------------------------------------------------------------------------- - - @Override - public void visit(@NotNull final Count count) { - addNoInputOperator(new CountAggregationOperator(count.column().name())); - } - - @Override - public void visit(@NotNull final CountWhere countWhere) { + final void addCountWhereOperator(@NotNull CountWhere countWhere) { final WhereFilter[] whereFilters = WhereFilter.fromInternal(countWhere.filter()); final Map inputColumnRecorderMap = new HashMap<>(); @@ -737,6 +700,47 @@ public void visit(@NotNull final CountWhere countWhere) { addOperator(new CountWhereOperator(countWhere.column().name(), whereFilters, recorders, filterRecorders), null, inputColumnNames); } + } + + // ----------------------------------------------------------------------------------------------------------------- + // Standard Aggregations + // ----------------------------------------------------------------------------------------------------------------- + + /** + * Implementation class for conversion from a collection of {@link Aggregation aggregations} to an + * {@link AggregationContext} for standard aggregations. Accumulates state by visiting each aggregation. + */ + private final class NormalConverter extends Converter { + private final QueryCompilerRequestProcessor.BatchProcessor compilationProcessor; + + private NormalConverter( + @NotNull final Table table, + final boolean requireStateChangeRecorder, + @NotNull final String... groupByColumnNames) { + super(table, requireStateChangeRecorder, groupByColumnNames); + this.compilationProcessor = QueryCompilerRequestProcessor.batch(); + } + + @Override + AggregationContext build() { + final AggregationContext resultContext = super.build(); + compilationProcessor.compile(); + return resultContext; + } + + // ------------------------------------------------------------------------------------------------------------- + // Aggregation.Visitor + // ------------------------------------------------------------------------------------------------------------- + + @Override + public void visit(@NotNull final Count count) { + addNoInputOperator(new CountAggregationOperator(count.column().name())); + } + + @Override + public void visit(@NotNull final CountWhere countWhere) { + addCountWhereOperator(countWhere); + } @Override public void visit(@NotNull final FirstRowKey firstRowKey) { @@ -1051,7 +1055,7 @@ public void visit(@NotNull final Count count) { @Override public void visit(@NotNull final CountWhere countWhere) { - addNoInputOperator(new CountAggregationOperator(countWhere.column().name())); + addCountWhereOperator(countWhere); } @Override diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/TestRollup.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/TestRollup.java new file mode 100644 index 00000000000..d50471f4c93 --- /dev/null +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/TestRollup.java @@ -0,0 +1,124 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.engine.table.impl; + +import io.deephaven.api.agg.Aggregation; +import io.deephaven.engine.table.Table; +import io.deephaven.engine.table.hierarchical.RollupTable; +import io.deephaven.engine.testutil.*; +import io.deephaven.engine.testutil.generator.*; +import io.deephaven.engine.testutil.testcase.RefreshingTableTestCase; +import io.deephaven.test.types.OutOfBandTest; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.util.*; + +import static io.deephaven.api.agg.Aggregation.*; +import static io.deephaven.engine.testutil.TstUtils.*; + +@Category(OutOfBandTest.class) +public class TestRollup extends RefreshingTableTestCase { + // This is the list of supported aggregations for rollup. These are all using `intCol` as the column to aggregate + // because the re-aggregation logic is effectively the same for all column types. + private final Collection aggs = List.of( + AggAbsSum("absSum=intCol"), + AggAvg("avg=intCol"), + AggCount("count"), + AggCountWhere("countWhere", "intCol > 50"), + AggCountDistinct("countDistinct=intCol"), + AggDistinct("distinct=intCol"), + AggFirst("first=intCol"), + AggLast("last=intCol"), + AggMax("max=intCol"), + AggMin("min=intCol"), + AggSortedFirst("Sym", "firstSorted=intCol"), + AggSortedLast("Sym", "lastSorted=intCol"), + AggStd("std=intCol"), + AggSum("sum=intCol"), + AggUnique("unique=intCol"), + AggVar("var=intCol"), + AggWAvg("intCol", "wavg=intCol"), + AggWSum("intCol", "wsum=intCol")); + + // Companion list of columns to compare between rollup root and the zero-key equivalent + private final String[] columnsToCompare = new String[] { + "absSum", + "avg", + "count", + "countWhere", + "countDistinct", + "distinct", + "first", + "last", + "max", + "min", + "firstSorted", + "lastSorted", + "std", + "sum", + "unique", + "var", + "wavg", + "wsum" + }; + + @SuppressWarnings("rawtypes") + private final ColumnInfo[] columnInfo = initColumnInfos( + new String[] {"Sym", "intCol"}, + new SetGenerator<>("a", "b", "c", "d"), + new IntGenerator(10, 100)); + + private QueryTable createTable(boolean refreshing, int size, Random random) { + return getTable(refreshing, size, random, columnInfo); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + @Test + public void testRollup() { + final Random random = new Random(0); + // Create the test table + final Table testTable = createTable(false, 100_000, random); + + final RollupTable rollupTable = testTable.rollup(aggs, false, "Sym"); + final Table rootTable = rollupTable.getRoot(); + + final Table actual = rootTable.select(columnsToCompare); + final Table expected = testTable.aggBy(aggs); + + // Compare the zero-key equivalent table to the rollup table root + TstUtils.assertTableEquals(actual, expected); + } + + @Test + public void testRollupIncremental() { + for (int size = 10; size <= 1000; size *= 10) { + testRollupIncrementalInternal("size-" + size, size); + } + } + + private void testRollupIncrementalInternal(final String ctxt, final int size) { + final Random random = new Random(0); + + final QueryTable testTable = createTable(true, size * 10, random); + EvalNuggetInterface[] en = new EvalNuggetInterface[] { + new QueryTableTest.TableComparator( + testTable.rollup(aggs, false, "Sym") + .getRoot().select(columnsToCompare), + testTable.aggBy(aggs)) + }; + + final int steps = 100; + for (int step = 0; step < steps; step++) { + if (RefreshingTableTestCase.printTableUpdates) { + System.out.println("Step = " + step); + } + simulateShiftAwareStep(ctxt + " step == " + step, size, random, testTable, columnInfo, en); + } + } +} diff --git a/py/server/tests/test_rollup_tree_table.py b/py/server/tests/test_rollup_tree_table.py index 2f1d76bdafd..409a9319f30 100644 --- a/py/server/tests/test_rollup_tree_table.py +++ b/py/server/tests/test_rollup_tree_table.py @@ -4,7 +4,7 @@ import unittest from deephaven import read_csv, empty_table -from deephaven.agg import sum_, avg, count_, first, last, max_, min_, std, abs_sum, \ +from deephaven.agg import sum_, avg, count_, count_where, first, last, max_, min_, std, abs_sum, \ var from deephaven.filters import Filter from deephaven.table import NodeType @@ -18,6 +18,7 @@ def setUp(self): self.aggs_for_rollup = [ avg(["aggAvg=var"]), count_("aggCount"), + count_where("aggCountWhere", "var > 0"), first(["aggFirst=var"]), last(["aggLast=var"]), max_(["aggMax=var"]),