Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct weighted summation null handling behavior #5660

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.deephaven.chunk.attributes.Values;
import io.deephaven.chunk.*;
import io.deephaven.util.mutable.MutableInt;
import io.deephaven.util.mutable.MutableLong;

import java.util.Collections;
import java.util.Map;
Expand Down Expand Up @@ -46,12 +47,12 @@ public void addChunk(BucketedContext bucketedContext, Chunk<? extends Values> va
IntChunk<ChunkPositions> startPositions, IntChunk<ChunkLengths> length,
WritableBooleanChunk<Values> stateModified) {
final Context context = (Context) bucketedContext;
final LongChunk<? extends Values> doubleValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getAddedWeights();
Assert.neqNull(weightValues, "weightValues");
for (int ii = 0; ii < startPositions.size(); ++ii) {
final int startPosition = startPositions.get(ii);
stateModified.set(ii, addChunk(doubleValues, weightValues, startPosition, length.get(ii),
stateModified.set(ii, addChunk(longValues, weightValues, startPosition, length.get(ii),
destinations.get(startPosition)));
}
}
Expand All @@ -62,12 +63,12 @@ public void removeChunk(BucketedContext bucketedContext, Chunk<? extends Values>
IntChunk<ChunkPositions> startPositions, IntChunk<ChunkLengths> length,
WritableBooleanChunk<Values> stateModified) {
final Context context = (Context) bucketedContext;
final LongChunk<? extends Values> doubleValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getRemovedWeights();
Assert.neqNull(weightValues, "weightValues");
for (int ii = 0; ii < startPositions.size(); ++ii) {
final int startPosition = startPositions.get(ii);
stateModified.set(ii, removeChunk(doubleValues, weightValues, startPosition, length.get(ii),
stateModified.set(ii, removeChunk(longValues, weightValues, startPosition, length.get(ii),
destinations.get(startPosition)));
}
}
Expand All @@ -93,18 +94,18 @@ public void modifyChunk(BucketedContext bucketedContext, Chunk<? extends Values>
public boolean addChunk(SingletonContext singletonContext, int chunkSize, Chunk<? extends Values> values,
LongChunk<? extends RowKeys> inputRowKeys, long destination) {
final Context context = (Context) singletonContext;
final LongChunk<? extends Values> doubleValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getAddedWeights();
return addChunk(doubleValues, weightValues, 0, values.size(), destination);
return addChunk(longValues, weightValues, 0, values.size(), destination);
}

@Override
public boolean removeChunk(SingletonContext singletonContext, int chunkSize, Chunk<? extends Values> values,
LongChunk<? extends RowKeys> inputRowKeys, long destination) {
final Context context = (Context) singletonContext;
final LongChunk<? extends Values> doubleValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getRemovedWeights();
return removeChunk(doubleValues, weightValues, 0, values.size(), destination);
return removeChunk(longValues, weightValues, 0, values.size(), destination);
}

@Override
Expand All @@ -121,19 +122,19 @@ public boolean modifyChunk(SingletonContext singletonContext, int chunkSize, Chu
newDoubleValues.size(), destination);
}

private static void sumChunks(LongChunk<? extends Values> doubleValues, LongChunk<? extends Values> weightValues,
private static void sumChunks(LongChunk<? extends Values> longValues, LongChunk<? extends Values> weightValues,
int start,
int length,
MutableInt normalOut,
MutableInt weightedSumOut) {
MutableLong weightedSumOut) {
int normal = 0;
int weightedSum = 0;
long weightedSum = 0;

for (int ii = 0; ii < length; ++ii) {
final double weight = weightValues.get(start + ii);
final double component = doubleValues.get(start + ii);
final long weight = weightValues.get(start + ii);
final long component = longValues.get(start + ii);

if (weight == QueryConstants.NULL_DOUBLE || component == QueryConstants.NULL_DOUBLE) {
if (weight == QueryConstants.NULL_LONG || component == QueryConstants.NULL_LONG) {
continue;
}

Expand All @@ -148,12 +149,12 @@ private static void sumChunks(LongChunk<? extends Values> doubleValues, LongChun
private boolean addChunk(LongChunk<? extends Values> longValues, LongChunk<? extends Values> weightValues,
int start, int length, long destination) {
final MutableInt normalOut = new MutableInt();
final MutableInt weightedSumOut = new MutableInt();
final MutableLong weightedSumOut = new MutableLong();

sumChunks(longValues, weightValues, start, length, normalOut, weightedSumOut);

final int newNormal = normalOut.get();
final int newWeightedSum = weightedSumOut.get();
final long newWeightedSum = weightedSumOut.get();

final long totalNormal;
final long existingNormal = normalCount.getUnsafe(destination);
Expand All @@ -171,21 +172,21 @@ private boolean addChunk(LongChunk<? extends Values> longValues, LongChunk<? ext
weightedSum.set(destination, totalWeightedSum);
}

final double existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
final long existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
return totalWeightedSum != existingResult;
}
return false;
}

private boolean removeChunk(LongChunk<? extends Values> doubleValues, LongChunk<? extends Values> weightValues,
private boolean removeChunk(LongChunk<? extends Values> longValues, LongChunk<? extends Values> weightValues,
int start, int length, long destination) {
final MutableInt normalOut = new MutableInt();
final MutableInt weightedSumOut = new MutableInt();
final MutableLong weightedSumOut = new MutableLong();

sumChunks(doubleValues, weightValues, start, length, normalOut, weightedSumOut);
sumChunks(longValues, weightValues, start, length, normalOut, weightedSumOut);

final int newNormal = normalOut.get();
final int newWeightedSum = weightedSumOut.get();
final long newWeightedSum = weightedSumOut.get();

final long totalNormal;
final long existingNormal = normalCount.getUnsafe(destination);
Expand Down Expand Up @@ -226,17 +227,17 @@ private boolean modifyChunk(LongChunk<? extends Values> prevDoubleValues,
LongChunk<? extends Values> prevWeightValues, LongChunk<? extends Values> newDoubleValues,
LongChunk<? extends Values> newWeightValues, int start, int length, long destination) {
final MutableInt normalOut = new MutableInt();
final MutableInt weightedSumOut = new MutableInt();
final MutableLong weightedSumOut = new MutableLong();

sumChunks(prevDoubleValues, prevWeightValues, start, length, normalOut, weightedSumOut);

final int prevNormal = normalOut.get();
final int prevWeightedSum = weightedSumOut.get();
final long prevWeightedSum = weightedSumOut.get();

sumChunks(newDoubleValues, newWeightValues, start, length, normalOut, weightedSumOut);

final int newNormal = normalOut.get();
final int newWeightedSum = weightedSumOut.get();
final long newWeightedSum = weightedSumOut.get();

final long totalNormal;
final long existingNormal = normalCount.getUnsafe(destination);
Expand All @@ -255,12 +256,12 @@ private boolean modifyChunk(LongChunk<? extends Values> prevDoubleValues,
weightedSum.set(destination, totalWeightedSum);
}

final double existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
final long existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
return totalWeightedSum != existingResult;
} else {
if (prevNormal > 0) {
weightedSum.set(destination, 0L);
resultColumn.set(destination, QueryConstants.NULL_DOUBLE);
resultColumn.set(destination, QueryConstants.NULL_LONG);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,40 @@ private void testWeightedAvgByIncremental(int size, int seed) {

}

@Test
public void testWeightedSumByLong() {
final QueryTable table = testRefreshingTable(i(2, 4, 6).toTracking(),
col("Long1", 2L, 4L, 6L), col("Long2", 1L, 2L, 3L));
final Table result = table.wsumBy("Long2");
TableTools.show(result);
TestCase.assertEquals(1, result.size());
long result_wsum = result.getColumnSource("Long1", long.class).getLong(result.getRowSet().firstRowKey());
long wsum = 2 + 8 + 18;
TestCase.assertEquals(wsum, result_wsum);

final ControlledUpdateGraph updateGraph = ExecutionContext.getContext().getUpdateGraph().cast();
updateGraph.runWithinUnitTestCycle(() -> {
addToTable(table, i(8), col("Long1", (long) Integer.MAX_VALUE), col("Long2", 7L));
table.notifyListeners(i(8), i(), i());
});
show(result);
result_wsum = result.getColumnSource("Long1", long.class).getLong(result.getRowSet().firstRowKey());
wsum = wsum + (7L * (long) Integer.MAX_VALUE);
TestCase.assertEquals(wsum, result_wsum);
}

@Test
public void testId5522() {
final QueryTable table = testRefreshingTable(i(2, 4, 6).toTracking(),
col("Long1", 10L, 20L, 30L), col("Long2", 1L, NULL_LONG, 1L));
final Table result = table.wsumBy("Long2");
TableTools.show(result);
TestCase.assertEquals(1, result.size());
long result_wsum = result.getColumnSource("Long1", long.class).getLong(result.getRowSet().firstRowKey());
long wsum = 10 + 30;
TestCase.assertEquals(wsum, result_wsum);
}

@Test
public void testWeightedSumByIncremental() {
final int[] sizes = {10, 50, 200};
Expand Down
Loading