Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to Slice 2.0 #18471

Merged
merged 4 commits into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@
<classifier>jdk11</classifier>
</dependency>

<dependency>
<groupId>com.teradata</groupId>
<artifactId>re2j-td</artifactId>
</dependency>

<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
Expand Down Expand Up @@ -252,6 +247,11 @@
<artifactId>opentelemetry-context</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>re2j</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-array</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,76 @@ public void readBytes(byte[] destination, int destinationIndex, int length)
}
}

@Override
public void readShorts(short[] destination, int destinationIndex, int length)
{
ReadBuffer buffer = buffers[0];
int shortsRemaining = length;
while (shortsRemaining > 0) {
ensureReadable(min(Long.BYTES, shortsRemaining * Short.BYTES));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand this line. If there are, say, 100 shorts remaining, this would do:

ensureReadable(min(8, 200)) -> ensureReadable(8)

Also, what's the significance of Long.BYTES here? Why would this method care about having at least 8 readable bytes, assuming that's the intention?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely agree. This is the style of the other methods in this class I had to ask @arhimondr how this code works, and this is what he said:

Hmm, it does look weird. I need to read it carefully to tell you more.
It basically tells that "at least" 8 bytes (or less if remaining bytes is less than 8" should be available
If not available - it will read an entire encryption / compression block

I think we all agree it need some renaming

int shortsToRead = min(shortsRemaining, buffer.available() / Short.BYTES);
buffer.readShorts(destination, destinationIndex, shortsToRead);
shortsRemaining -= shortsToRead;
destinationIndex += shortsToRead;
}
}

@Override
public void readInts(int[] destination, int destinationIndex, int length)
{
ReadBuffer buffer = buffers[0];
int intsRemaining = length;
while (intsRemaining > 0) {
ensureReadable(min(Long.BYTES, intsRemaining * Integer.BYTES));
int intsToRead = min(intsRemaining, buffer.available() / Integer.BYTES);
buffer.readInts(destination, destinationIndex, intsToRead);
intsRemaining -= intsToRead;
destinationIndex += intsToRead;
}
}

@Override
public void readLongs(long[] destination, int destinationIndex, int length)
{
ReadBuffer buffer = buffers[0];
int longsRemaining = length;
while (longsRemaining > 0) {
ensureReadable(min(Long.BYTES, longsRemaining * Long.BYTES));
int longsToRead = min(longsRemaining, buffer.available() / Long.BYTES);
buffer.readLongs(destination, destinationIndex, longsToRead);
longsRemaining -= longsToRead;
destinationIndex += longsToRead;
}
}

@Override
public void readFloats(float[] destination, int destinationIndex, int length)
{
ReadBuffer buffer = buffers[0];
int floatsRemaining = length;
while (floatsRemaining > 0) {
ensureReadable(min(Long.BYTES, floatsRemaining * Float.BYTES));
int floatsToRead = min(floatsRemaining, buffer.available() / Float.BYTES);
buffer.readFloats(destination, destinationIndex, floatsToRead);
floatsRemaining -= floatsToRead;
destinationIndex += floatsToRead;
}
}

@Override
public void readDoubles(double[] destination, int destinationIndex, int length)
{
ReadBuffer buffer = buffers[0];
int doublesRemaining = length;
while (doublesRemaining > 0) {
ensureReadable(min(Long.BYTES, doublesRemaining * Double.BYTES));
int doublesToRead = min(doublesRemaining, buffer.available() / Double.BYTES);
buffer.readDoubles(destination, destinationIndex, doublesToRead);
doublesRemaining -= doublesToRead;
destinationIndex += doublesToRead;
}
}

@Override
public void readBytes(Slice destination, int destinationIndex, int length)
{
Expand Down Expand Up @@ -469,7 +539,6 @@ private static class ReadBuffer
public ReadBuffer(Slice slice)
{
requireNonNull(slice, "slice is null");
checkArgument(slice.hasByteArray(), "slice is expected to be based on a byte array");
this.slice = slice;
limit = slice.length();
}
Expand Down Expand Up @@ -572,6 +641,36 @@ public void readBytes(byte[] destination, int destinationIndex, int length)
position += length;
}

public void readShorts(short[] destination, int destinationIndex, int length)
{
slice.getShorts(position, destination, destinationIndex, length);
position += length * Short.BYTES;
}

public void readInts(int[] destination, int destinationIndex, int length)
{
slice.getInts(position, destination, destinationIndex, length);
position += length * Integer.BYTES;
}

public void readLongs(long[] destination, int destinationIndex, int length)
{
slice.getLongs(position, destination, destinationIndex, length);
position += length * Long.BYTES;
}

public void readFloats(float[] destination, int destinationIndex, int length)
{
slice.getFloats(position, destination, destinationIndex, length);
position += length * Float.BYTES;
}

public void readDoubles(double[] destination, int destinationIndex, int length)
{
slice.getDoubles(position, destination, destinationIndex, length);
position += length * Double.BYTES;
}

public void readBytes(Slice destination, int destinationIndex, int length)
{
slice.getBytes(position, destination, destinationIndex, length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,91 @@ public void writeBytes(byte[] source, int sourceIndex, int length)
uncompressedSize += length;
}

@Override
public void writeShorts(short[] source, int sourceIndex, int length)
{
WriteBuffer buffer = buffers[0];
int currentIndex = sourceIndex;
int shortsRemaining = length;
while (shortsRemaining > 0) {
ensureCapacityFor(min(Long.BYTES, shortsRemaining * Short.BYTES));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

int bufferCapacity = buffer.remainingCapacity();
int shortsToCopy = min(shortsRemaining, bufferCapacity / Short.BYTES);
buffer.writeShorts(source, currentIndex, shortsToCopy);
currentIndex += shortsToCopy;
shortsRemaining -= shortsToCopy;
}
uncompressedSize += length * Short.BYTES;
}

@Override
public void writeInts(int[] source, int sourceIndex, int length)
{
WriteBuffer buffer = buffers[0];
int currentIndex = sourceIndex;
int intsRemaining = length;
while (intsRemaining > 0) {
ensureCapacityFor(min(Long.BYTES, intsRemaining * Integer.BYTES));
int bufferCapacity = buffer.remainingCapacity();
int intsToCopy = min(intsRemaining, bufferCapacity / Integer.BYTES);
buffer.writeInts(source, currentIndex, intsToCopy);
currentIndex += intsToCopy;
intsRemaining -= intsToCopy;
}
uncompressedSize += length * Integer.BYTES;
}

@Override
public void writeLongs(long[] source, int sourceIndex, int length)
{
WriteBuffer buffer = buffers[0];
int currentIndex = sourceIndex;
int longsRemaining = length;
while (longsRemaining > 0) {
ensureCapacityFor(min(Long.BYTES, longsRemaining * Long.BYTES));
int bufferCapacity = buffer.remainingCapacity();
int longsToCopy = min(longsRemaining, bufferCapacity / Long.BYTES);
buffer.writeLongs(source, currentIndex, longsToCopy);
currentIndex += longsToCopy;
longsRemaining -= longsToCopy;
}
uncompressedSize += length * Long.BYTES;
}

@Override
public void writeFloats(float[] source, int sourceIndex, int length)
{
WriteBuffer buffer = buffers[0];
int currentIndex = sourceIndex;
int floatsRemaining = length;
while (floatsRemaining > 0) {
ensureCapacityFor(min(Long.BYTES, floatsRemaining * Float.BYTES));
int bufferCapacity = buffer.remainingCapacity();
int floatsToCopy = min(floatsRemaining, bufferCapacity / Float.BYTES);
buffer.writeFloats(source, currentIndex, floatsToCopy);
currentIndex += floatsToCopy;
floatsRemaining -= floatsToCopy;
}
uncompressedSize += length * Float.BYTES;
}

@Override
public void writeDoubles(double[] source, int sourceIndex, int length)
{
WriteBuffer buffer = buffers[0];
int currentIndex = sourceIndex;
int doublesRemaining = length;
while (doublesRemaining > 0) {
ensureCapacityFor(min(Long.BYTES, doublesRemaining * Double.BYTES));
int bufferCapacity = buffer.remainingCapacity();
int doublesToCopy = min(doublesRemaining, bufferCapacity / Double.BYTES);
buffer.writeDoubles(source, currentIndex, doublesToCopy);
currentIndex += doublesToCopy;
doublesRemaining -= doublesToCopy;
}
uncompressedSize += length * Double.BYTES;
}

public Slice closePage()
{
compress();
Expand Down Expand Up @@ -589,6 +674,36 @@ public void writeBytes(byte[] source, int sourceIndex, int length)
position += length;
}

public void writeShorts(short[] source, int sourceIndex, int length)
{
slice.setShorts(position, source, sourceIndex, length);
position += length * Short.BYTES;
}

public void writeInts(int[] source, int sourceIndex, int length)
{
slice.setInts(position, source, sourceIndex, length);
position += length * Integer.BYTES;
}

public void writeLongs(long[] source, int sourceIndex, int length)
{
slice.setLongs(position, source, sourceIndex, length);
position += length * Long.BYTES;
}

public void writeFloats(float[] source, int sourceIndex, int length)
{
slice.setFloats(position, source, sourceIndex, length);
position += length * Float.BYTES;
}

public void writeDoubles(double[] source, int sourceIndex, int length)
{
slice.setDoubles(position, source, sourceIndex, length);
position += length * Double.BYTES;
}

public void skip(int length)
{
position += length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.io.ByteStreams.readFully;
import static io.airlift.slice.UnsafeSlice.getIntUnchecked;
import static io.trino.block.BlockSerdeUtil.readBlock;
import static io.trino.block.BlockSerdeUtil.writeBlock;
import static io.trino.execution.buffer.PageCodecMarker.COMPRESSED;
Expand Down Expand Up @@ -217,7 +216,7 @@ public static Slice readSerializedPage(Slice headerSlice, InputStream inputStrea
{
checkArgument(headerSlice.length() == SERIALIZED_PAGE_HEADER_SIZE, "headerSlice length should equal to %s", SERIALIZED_PAGE_HEADER_SIZE);

int compressedSize = getIntUnchecked(headerSlice, SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET);
int compressedSize = headerSlice.getIntUnchecked(SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET);
byte[] outputBuffer = new byte[SERIALIZED_PAGE_HEADER_SIZE + compressedSize];
headerSlice.getBytes(0, outputBuffer, 0, SERIALIZED_PAGE_HEADER_SIZE);
readFully(inputStream, outputBuffer, SERIALIZED_PAGE_HEADER_SIZE, compressedSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.operator.aggregation.state.NullableLongState;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
Expand All @@ -34,7 +36,6 @@

import java.lang.invoke.MethodHandle;

import static io.airlift.slice.Slices.wrappedLongArray;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
Expand Down Expand Up @@ -89,7 +90,9 @@ public static void output(
out.appendNull();
}
else {
VARBINARY.writeSlice(out, wrappedLongArray(state.getValue()));
Slice value = Slices.allocate(Long.BYTES);
value.setLong(0, state.getValue());
VARBINARY.writeSlice(out, value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ public NumericHistogram(Slice serialized, int buffer)
values = new double[maxBuckets + buffer];
weights = new double[maxBuckets + buffer];

input.readBytes(Slices.wrappedDoubleArray(values), nextIndex * SizeOf.SIZE_OF_DOUBLE);
input.readBytes(Slices.wrappedDoubleArray(weights), nextIndex * SizeOf.SIZE_OF_DOUBLE);
input.readDoubles(values, 0, nextIndex);
input.readDoubles(weights, 0, nextIndex);
}

public Slice serialize()
Expand All @@ -90,8 +90,8 @@ public Slice serialize()
.appendByte(FORMAT_TAG)
.appendInt(maxBuckets)
.appendInt(nextIndex)
.appendBytes(Slices.wrappedDoubleArray(values, 0, nextIndex))
.appendBytes(Slices.wrappedDoubleArray(weights, 0, nextIndex))
.appendDoubles(values, 0, nextIndex)
.appendDoubles(weights, 0, nextIndex)
.getUnderlyingSlice();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
long[] buffer = new long[4];
Slice buffer = Slices.allocate(Long.BYTES * 4);
long high = decimal[offset];
long low = decimal[offset + 1];

buffer[0] = low;
buffer[1] = high;
buffer.setLong(0, low);
buffer.setLong(Long.BYTES, high);
// if high = 0, the count will overwrite it
int countOffset = 1 + (high == 0 ? 0 : 1);
// append count, overflow
buffer[countOffset] = count;
buffer[countOffset + 1] = overflow;
buffer.setLong(Long.BYTES * countOffset, count);
buffer.setLong(Long.BYTES * (countOffset + 1), overflow);

// cases
// high == 0 (countOffset = 1)
Expand All @@ -59,7 +59,7 @@ public void serialize(LongDecimalWithOverflowAndLongState state, BlockBuilder ou
// overflow == 0 & count == 1 -> bufferLength = 2
// overflow != 0 || count != 1 -> bufferLength = 4
int bufferLength = countOffset + ((overflow == 0 & count == 1) ? 0 : 2);
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength));
VARBINARY.writeSlice(out, buffer, 0, bufferLength * Long.BYTES);
}
else {
out.appendNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ public void serialize(LongDecimalWithOverflowState state, BlockBuilder out)
long overflow = state.getOverflow();
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();
long[] buffer = new long[3];
Slice buffer = Slices.allocate(Long.BYTES * 3);
long low = decimal[offset + 1];
long high = decimal[offset];
buffer[0] = low;
buffer[1] = high;
buffer[2] = overflow;
buffer.setLong(0, low);
buffer.setLong(Long.BYTES, high);
buffer.setLong(Long.BYTES * 2, overflow);
// if high == 0 and overflow == 0 we only write low (bufferLength = 1)
// if high != 0 and overflow == 0 we write both low and high (bufferLength = 2)
// if overflow != 0 we write all values (bufferLength = 3)
int decimalsCount = 1 + (high == 0 ? 0 : 1);
int bufferLength = overflow == 0 ? decimalsCount : 3;
VARBINARY.writeSlice(out, Slices.wrappedLongArray(buffer, 0, bufferLength));
VARBINARY.writeSlice(out, buffer, 0, bufferLength * Long.BYTES);
}
else {
out.appendNull();
Expand Down
Loading