Skip to content

Commit

Permalink
refactoring of aggregation code
Browse files Browse the repository at this point in the history
  • Loading branch information
vmzakharov committed Mar 20, 2024
1 parent a5a472a commit 71b7943
Show file tree
Hide file tree
Showing 13 changed files with 156 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

public abstract class AggregateFunction
{
private final String columnName;
private final String sourceColumnName;
private final String targetColumnName;

public AggregateFunction(String newSourceColumnName)
Expand All @@ -26,7 +26,7 @@ public AggregateFunction(String newSourceColumnName)

public AggregateFunction(String newSourceColumnName, String newTargetColumnName)
{
this.columnName = newSourceColumnName;
this.sourceColumnName = newSourceColumnName;
this.targetColumnName = newTargetColumnName;
}

Expand Down Expand Up @@ -105,9 +105,21 @@ public AggregateFunction cloneWith(String newSourceColumnName, String newTargetC
}
}

/**
* @deprecated use <code>getSourceColumnName()</code> instead
* @return the name of the aggregation source column, i.e. the column the values of which will be aggregated
*/
public String getColumnName()
{
return this.columnName;
return this.sourceColumnName;
}

/**
* @return the name of the aggregation source column, i.e. the column the values of which will be aggregated
*/
public String getSourceColumnName()
{
return this.sourceColumnName;
}

public String getTargetColumnName()
Expand All @@ -127,33 +139,49 @@ public boolean supportsSourceType(ValueType type)
return this.supportedSourceTypes().contains(type);
}

public Object applyToColumn(DfColumn column)
{
switch (column.getType())
{
case LONG:
return this.applyToLongColumn((DfLongColumn) column);
case DOUBLE:
return this.applyToDoubleColumn((DfDoubleColumn) column);
case INT:
return this.applyToIntColumn((DfIntColumn) column);
default:
return this.applyToObjectColumn((DfObjectColumn<?>) column);
}
}

public Object applyToDoubleColumn(DfDoubleColumn doubleColumn)
{
throw this.notApplicable("double values");
throw this.notApplicable(doubleColumn);
}

public Object applyToLongColumn(DfLongColumn longColumn)
{
throw this.notApplicable("long values");
throw this.notApplicable(longColumn);
}

public Object applyToIntColumn(DfIntColumn longColumn)
public Object applyToIntColumn(DfIntColumn intColumn)
{
throw this.notApplicable("int values");
throw this.notApplicable(intColumn);
}

public Object applyToObjectColumn(DfObjectColumn<?> objectColumn)
{
throw this.notApplicable("non-numeric values");
throw this.notApplicable(objectColumn);
}

protected RuntimeException notApplicable(String scope)
protected RuntimeException notApplicable(DfColumn column)
{
return exceptionByKey("AGG_NOT_APPLICABLE")
return exceptionByKey("AGG_COL_TYPE_UNSUPPORTED")
.with("operation", this.getName())
.with("operationDescription", this.getDescription())
.with("operationScope", scope)
.getUnsupported();
.with("columnName", column.getName())
.with("columnType", column.getType().toString().toLowerCase())
.get();
}

public int intInitialValue()
Expand Down Expand Up @@ -239,25 +267,9 @@ public String getDescription()
return this.getName();
}

// TODO - refactor default*IfEmpty to have a single method
public Object defaultObjectIfEmpty()
{
throw this.notApplicable("empty lists");
}

public long defaultLongIfEmpty()
{
throw this.notApplicable("empty lists");
}

public int defaultIntIfEmpty()
{
throw this.notApplicable("empty lists");
}

public double defaultDoubleIfEmpty()
public Object valueForEmptyColumn(DfColumn column)
{
throw this.notApplicable("empty lists");
throw this.notApplicable(column);
}

public long getLongValue(DfColumn sourceColumn, int sourceRowIndex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,13 +813,13 @@ public DataFrame pivot(
pivotColumnValue -> {
aggregators.forEach(
aggregator -> {
DfColumn valueColum = this.getColumnNamed(aggregator.getColumnName());
DfColumn valueColum = this.getColumnNamed(aggregator.getSourceColumnName());
ValueType targetType = aggregator.targetColumnType(valueColum.getType());
String targetColumnName = singleAggregator ? pivotColumnValue : pivotColumnValue + ":" + aggregator.getTargetColumnName();

pivoted.addColumn(targetColumnName, targetType);

AggregateFunction aggregatorForPivotValue = aggregator.cloneWith(aggregator.getColumnName(), targetColumnName);
AggregateFunction aggregatorForPivotValue = aggregator.cloneWith(aggregator.getSourceColumnName(), targetColumnName);

pivotColumnNames.add(targetColumnName);
aggregatorsForPivot.add(aggregatorForPivotValue);
Expand Down Expand Up @@ -855,7 +855,7 @@ public DataFrame pivot(
aggregatorsByPivotValue
.get(pivotValue)
.forEach(agg -> {
DfColumn valueColumn = this.getColumnNamed(agg.getColumnName());
DfColumn valueColumn = this.getColumnNamed(agg.getSourceColumnName());
inputRowCountPerAggregateRow.get(agg.getTargetColumnName())[accumulatorRowIndex]++;
pivoted.getColumnNamed(agg.getTargetColumnName())
.applyAggregator(accumulatorRowIndex, valueColumn, finalRowIndex, agg);
Expand All @@ -875,7 +875,7 @@ public DataFrame pivot(
*/
public DataFrame aggregate(ListIterable<AggregateFunction> aggregators)
{
ListIterable<DfColumn> columnsToAggregate = this.getColumnsToAggregate(aggregators.collect(AggregateFunction::getColumnName));
ListIterable<DfColumn> columnsToAggregate = this.getColumnsToAggregate(aggregators.collect(AggregateFunction::getSourceColumnName));

DataFrame summedDataFrame = new DataFrame("Aggregate Of " + this.getName());

Expand Down Expand Up @@ -944,7 +944,7 @@ private DataFrame aggregateByWithIndex(

int[] inputRowCountPerAggregateRow = new int[this.rowCount()]; // sizing for the worst case scenario: no aggregation

ListIterable<String> columnsToAggregateNames = aggregators.collect(AggregateFunction::getColumnName);
ListIterable<String> columnsToAggregateNames = aggregators.collect(AggregateFunction::getSourceColumnName);
ListIterable<DfColumn> columnsToAggregate = this.getColumnsToAggregate(columnsToAggregateNames);

DataFrame aggregatedDataFrame = new DataFrame("Aggregate Of " + this.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,30 @@ default boolean isComputed()

void addEmptyValue();

default Object aggregate(AggregateFunction aggregator)
default Object aggregate(AggregateFunction aggregateFunction)
{
throw exceptionByKey("DF_COL_UNSUPPORTED_AGG")
.with("aggregationName", aggregator.getName())
.with("aggregationDescription", aggregator.getDescription())
.with("columnName", this.getName())
.with("columnType", this.getType())
.get();
if (aggregateFunction.supportsSourceType(this.getType()))
{
if (this.getSize() == 0)
{
return aggregateFunction.valueForEmptyColumn(this);
}

try
{
return aggregateFunction.applyToColumn(this);
}
catch (NullPointerException npe)
{
// npe can be thrown if there is a null value stored in a column of primitive type, this can happen when
// converting column values to a list.
return null;
}
}
else
{
throw aggregateFunction.notApplicable(this);
}
}

void applyAggregator(int targetRowIndex, DfColumn sourceColumn, int sourceRowIndex, AggregateFunction aggregateFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,6 @@ public DfColumn copyTo(DataFrame target)

protected abstract void addAllItemsFrom(DfDoubleColumn doubleColumn);

@Override
public Object aggregate(AggregateFunction aggregateFunction)
{
if (this.getSize() == 0)
{
return aggregateFunction.defaultDoubleIfEmpty();
}

try
{
return aggregateFunction.applyToDoubleColumn(this);
}
catch (NullPointerException npe)
{
return null;
}
}

@Override
public DfCellComparator columnComparator(DfColumn otherColumn)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,6 @@ public DfColumn copyTo(DataFrame target)

protected abstract void addAllItemsFrom(DfIntColumn items);

@Override
public Object aggregate(AggregateFunction aggregateFunction)
{
if (this.getSize() == 0)
{
return aggregateFunction.defaultIntIfEmpty();
}

try
{
return aggregateFunction.applyToIntColumn(this);
}
catch (NullPointerException npe)
{
return null;
}
}

@Override
public DfCellComparator columnComparator(DfColumn otherColumn)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,6 @@ public DfColumn copyTo(DataFrame target)

protected abstract void addAllItemsFrom(DfLongColumn items);

@Override
public Object aggregate(AggregateFunction aggregateFunction)
{
if (this.getSize() == 0)
{
return aggregateFunction.defaultLongIfEmpty();
}

try
{
return aggregateFunction.applyToLongColumn(this);
}
catch (NullPointerException npe)
{
return null;
}
}

@Override
public DfCellComparator columnComparator(DfColumn otherColumn)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,6 @@ default Object getObject(int rowIndex)
return this.getTypedObject(rowIndex);
}

@Override
default Object aggregate(AggregateFunction aggregateFunction)
{
if (aggregateFunction.supportsSourceType(this.getType()))
{
if (this.getSize() == 0)
{
return aggregateFunction.defaultObjectIfEmpty();
}

return aggregateFunction.applyToObjectColumn(this);
}
else
{
throw aggregateFunction.notApplicable("values of type " + this.getType());
}
}

default <IV> IV injectIntoBreakOnNulls(
IV injectedValue,
Function2<IV, T, IV> function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,7 @@ public long longInitialValue()
}

@Override
public long defaultLongIfEmpty()
{
return 0L;
}

@Override
public int defaultIntIfEmpty()
{
return 0;
}

@Override
public double defaultDoubleIfEmpty()
{
return 0.0;
}

@Override
public Object defaultObjectIfEmpty()
public Object valueForEmptyColumn(DfColumn column)
{
return 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,8 @@ protected Object objectAccumulator(Object currentAggregate, Object newValue)
}

@Override
public long defaultLongIfEmpty()
{
return 0L;
}

@Override
public int defaultIntIfEmpty()
public Object valueForEmptyColumn(DfColumn column)
{
return 0;
}

@Override
public double defaultDoubleIfEmpty()
{
return 0.0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static void initialize()
return;
}

addMessage("AGG_NOT_APPLICABLE", "Aggregation '${operation}' (${operationDescription}) cannot be performed on ${operationScope}");
addMessage("AGG_COL_TYPE_UNSUPPORTED", "Aggregation '${operation}' (${operationDescription}) cannot be performed on column ${columnName} of type ${columnType}");
addMessage("AGG_NO_INITIAL_VALUE", "Aggregation '${operation}' does not have a ${type} initial value");
addMessage("AGG_NO_ACCUMULATOR", "Aggregation '${operation}' does not support a ${type} accumulator");
addMessage("AGG_CANNOT_CLONE", "Cannot create a clone of aggregation '${operation}'");
Expand All @@ -38,7 +38,7 @@ public static void initialize()
addMessage("DF_CALC_COL_MODIFICATION", "Cannot directly modify computed column '${columnName}'");
addMessage("DF_CALC_COL_INFER_TYPE", "Cannot add calculated column ${columnName} to data frame ${dataFrameName}: failed to infer the expression type of '${expression}'\n${errorList}");
addMessage("DF_MERGE_COL_DIFF_TYPES", "Attempting to merge columns of different types: ${firstColumnName} (${firstColumnType}) and ${secondColumnName} (${secondColumnType})");
addMessage("DF_COL_UNSUPPORTED_AGG", "Aggregation ${aggregatorName} (${aggregationDescription}) cannot be performed on column ${columnName} of type ${columnType}");
addMessage("DF_COL_CONTAINS_NULL", "Column '${columnName}' contains null value in row ${rowIndex}");
addMessage("CSV_FILE_WRITE_FAIL", "Failed to write data frame to '${fileName}'");
addMessage("CSV_UNSUPPORTED_VAL_TO_STR", "Do not know how to convert value of type ${valueType} to a string");
addMessage("CSV_INFER_SCHEMA_FAIL", "Failed to infer schema from '${fileName}'");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public void testClone()
AggregateFunction sum = AggregateFunction.sum("Foo", "Bar");
AggregateFunction cloned = sum.cloneWith("Baz", "Qux");

assertEquals("Baz", cloned.getColumnName());
assertEquals("Baz", cloned.getSourceColumnName());
assertEquals("Qux", cloned.getTargetColumnName());
assertEquals(sum.getClass(), cloned.getClass());
}
Expand Down
Loading

0 comments on commit 71b7943

Please sign in to comment.