Skip to content

Commit

Permalink
adds int type support for built-in aggregation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vmzakharov committed Mar 20, 2024
1 parent 632a69c commit a5a472a
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.github.vmzakharov.ecdataframe.dataframe.DfDoubleColumn;
import io.github.vmzakharov.ecdataframe.dataframe.DfDoubleColumnStored;
import io.github.vmzakharov.ecdataframe.dataframe.DfIntColumn;
import io.github.vmzakharov.ecdataframe.dataframe.DfIntColumnStored;
import io.github.vmzakharov.ecdataframe.dataframe.DfLongColumn;
import io.github.vmzakharov.ecdataframe.dataframe.DfLongColumnStored;
import io.github.vmzakharov.ecdataframe.dataframe.DfObjectColumn;
Expand All @@ -22,10 +23,11 @@

import static io.github.vmzakharov.ecdataframe.dsl.value.ValueType.*;

// TODO - avoid overflows: the running sums are built with plain int/long addition, which wraps around silently
public class Avg
extends AggregateFunction
{
private static final ListIterable<ValueType> SUPPORTED_TYPES = Lists.immutable.of(LONG, DOUBLE, DECIMAL);
private static final ListIterable<ValueType> SUPPORTED_TYPES = Lists.immutable.of(INT, LONG, DOUBLE, DECIMAL);

public Avg(String newColumnName)
{
Expand Down Expand Up @@ -78,6 +80,12 @@ public Object applyToObjectColumn(DfObjectColumn<?> objectColumn)
return sum == null ? null : sum.divide(BigDecimal.valueOf(objectColumn.getSize()), RoundingMode.HALF_UP);
}

/**
 * Folds one more source value into the running int aggregate by adding it to the total.
 * Plain int addition is used, so a total outside the int range wraps around silently.
 *
 * @param currentAggregate the total accumulated so far
 * @param newValue the value to add to the total
 * @return the sum of the two arguments
 */
@Override
protected int intAccumulator(int currentAggregate, int newValue)
{
    int updatedTotal = newValue + currentAggregate;
    return updatedTotal;
}

@Override
protected long longAccumulator(long currentAggregate, long newValue)
{
Expand All @@ -96,6 +104,12 @@ protected Object objectAccumulator(Object currentAggregate, Object newValue)
return ((BigDecimal) currentAggregate).add((BigDecimal) newValue);
}

/**
 * Supplies the starting value for an int aggregate.
 *
 * @return zero, the identity for the addition performed by the int accumulator
 */
@Override
public int intInitialValue()
{
    final int emptyTotal = 0;
    return emptyTotal;
}

@Override
public long longInitialValue()
{
Expand Down Expand Up @@ -137,6 +151,23 @@ public void finishAggregating(DataFrame aggregatedDataFrame, int[] countsByRow)
}
}
}
else if (aggregatedColumn.getType().isInt())
{
DfIntColumnStored longColumn = (DfIntColumnStored) aggregatedColumn;

for (int rowIndex = 0; rowIndex < columnSize; rowIndex++)
{
if (!longColumn.isNull(rowIndex))
{
int aggregateValue = longColumn.getInt(rowIndex);

if (this.zeroContributorCheck(countsByRow[rowIndex], aggregateValue != 0, rowIndex))
{
longColumn.setInt(rowIndex, aggregateValue / countsByRow[rowIndex]);
}
}
}
}
else if (aggregatedColumn.getType().isDouble())
{
DfDoubleColumnStored doubleColumn = (DfDoubleColumnStored) aggregatedColumn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ public Object applyToIntColumn(DfIntColumn intColumn)
return intColumn.toIntList().max();
}

/**
 * Supplies the starting value for an int aggregate.
 *
 * @return {@link Integer#MIN_VALUE}, the smallest int, so any value compared against it
 *         in a "take the larger" step is kept
 */
@Override
public int intInitialValue()
{
    final int lowestPossible = Integer.MIN_VALUE;
    return lowestPossible;
}

@Override
public long longInitialValue()
{
Expand Down Expand Up @@ -97,6 +103,12 @@ protected double doubleAccumulator(double currentAggregate, double newValue)
return Math.max(currentAggregate, newValue);
}

/**
 * Keeps the larger of the running int aggregate and the incoming value.
 *
 * @param currentAggregate the largest value seen so far
 * @param newValue the value being folded in
 * @return whichever of the two arguments is greater
 */
@Override
protected int intAccumulator(int currentAggregate, int newValue)
{
    return newValue > currentAggregate ? newValue : currentAggregate;
}

@Override
protected Object objectAccumulator(Object currentAggregate, Object newValue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ public Object applyToObjectColumn(DfObjectColumn<?> objectColumn)
);
}

/**
 * Supplies the starting value for an int aggregate.
 *
 * @return {@link Integer#MAX_VALUE}, the largest int, so any value compared against it
 *         in a "take the smaller" step is kept
 */
@Override
public int intInitialValue()
{
    final int highestPossible = Integer.MAX_VALUE;
    return highestPossible;
}

@Override
public long longInitialValue()
{
Expand Down Expand Up @@ -98,6 +104,12 @@ protected double doubleAccumulator(double currentAggregate, double newValue)
return Math.min(currentAggregate, newValue);
}

/**
 * Keeps the smaller of the running int aggregate and the incoming value.
 *
 * @param currentAggregate the smallest value seen so far
 * @param newValue the value being folded in
 * @return whichever of the two arguments is lesser
 */
@Override
protected int intAccumulator(int currentAggregate, int newValue)
{
    return newValue < currentAggregate ? newValue : currentAggregate;
}

@Override
protected Object objectAccumulator(Object currentAggregate, Object newValue)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.github.vmzakharov.ecdataframe.dataframe.aggregation;

import io.github.vmzakharov.ecdataframe.dataframe.AggregateFunction;
import io.github.vmzakharov.ecdataframe.dataframe.DfColumn;
import io.github.vmzakharov.ecdataframe.dataframe.DfDecimalColumn;
import io.github.vmzakharov.ecdataframe.dataframe.DfDoubleColumn;
import io.github.vmzakharov.ecdataframe.dataframe.DfIntColumn;
Expand Down Expand Up @@ -91,6 +92,14 @@ public BigDecimal objectInitialValue()
return BigDecimal.ZERO;
}

/**
 * Reads the value at the given row of the source column as a long.
 * A column whose type reports itself as long is read through {@code DfLongColumn};
 * every other column is cast to {@code DfIntColumn} and read with {@code getInt},
 * with the result converted to long on return.
 *
 * @param sourceColumn the column to read from
 * @param sourceRowIndex the index of the row to read
 * @return the value at that row as a long
 */
@Override
public long getLongValue(DfColumn sourceColumn, int sourceRowIndex)
{
    if (sourceColumn.getType().isLong())
    {
        return ((DfLongColumn) sourceColumn).getLong(sourceRowIndex);
    }

    // anything that is not a long column is treated as an int column
    return ((DfIntColumn) sourceColumn).getInt(sourceRowIndex);
}

@Override
protected long longAccumulator(long currentAggregate, long newValue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ public void builtInAggregationFunctionNames()
@Test
public void builtInAggregationSupportedTypes()
{
Assert.assertEquals(Lists.immutable.of(LONG, DOUBLE, DECIMAL), avg("NA").supportedSourceTypes());
Assert.assertEquals(Lists.immutable.of(INT, LONG, DOUBLE, DECIMAL), avg("NA").supportedSourceTypes());
Assert.assertEquals(Lists.immutable.of(INT, LONG, DOUBLE, STRING, DATE, DATE_TIME, DECIMAL), count("NA").supportedSourceTypes());
Assert.assertEquals(Lists.immutable.of(INT, LONG, DOUBLE, DECIMAL), max("NA").supportedSourceTypes());
Assert.assertEquals(Lists.immutable.of(INT, LONG, DOUBLE, DECIMAL), min("NA").supportedSourceTypes());
Expand Down
Loading

0 comments on commit a5a472a

Please sign in to comment.