[SPARK-25317][CORE] Avoid perf regression in Murmur3 Hash on UTF8String #22338

Closed
wants to merge 2 commits into from
@@ -19,6 +19,7 @@

import com.google.common.primitives.Ints;

+ import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.types.UTF8String;

@@ -59,7 +60,7 @@ public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) {
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
int lengthInBytes = Ints.checkedCast(base.size());
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
- int h1 = hashBytesByIntBlock(base, seed);
+ int h1 = hashBytesByIntBlock(base, lengthInBytes, seed);
return fmix(h1, lengthInBytes);
}

@@ -69,22 +70,27 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i
}

public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) {
+ return hashUnsafeBytesBlock(base, Ints.checkedCast(base.size()), seed);
+ }
+
+ private static int hashUnsafeBytesBlock(MemoryBlock base, int lengthInBytes, int seed) {
// This is not compatible with original and another implementations.
// But remain it for backward compatibility for the components existing before 2.3.
- int lengthInBytes = Ints.checkedCast(base.size());
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
- int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed);
+ int h1 = hashBytesByIntBlock(base, lengthAligned, seed);
+ long offset = base.getBaseOffset();
+ Object o = base.getBaseObject();
for (int i = lengthAligned; i < lengthInBytes; i++) {
- int halfWord = base.getByte(i);
+ int halfWord = Platform.getByte(o, offset + i);
Member

So it seems the performance regression is due to the cost of virtual function calls on MemoryBlock?

Contributor Author

That was my guess too at first, but if you make only this change, performance does not change. What @kiszk said about the clue being the size of the generated Java bytecode seems reasonable, but it needs more investigation.

Member

OK. It seems there is more than one cause for this performance regression.
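
As a rough illustration of the two access styles discussed in this thread, the sketch below reads the trailing bytes once through an abstract accessor and once through a hoisted base array and offset. `Block`, `ArrayBlock`, and `TailReader` are hypothetical stand-ins, not Spark's `MemoryBlock` or `Platform`, and the sketch only shows the code shape; it makes no claim about which path is faster.

```java
// Hypothetical stand-in for a block abstraction with a virtual accessor.
abstract class Block {
    abstract int size();
    abstract byte getByte(int index);
}

// Hypothetical array-backed block; `offset` marks where the block starts in `data`.
final class ArrayBlock extends Block {
    final byte[] data;
    final int offset;
    final int length;

    ArrayBlock(byte[] data, int offset, int length) {
        this.data = data;
        this.offset = offset;
        this.length = length;
    }

    @Override int size() { return length; }
    @Override byte getByte(int index) { return data[offset + index]; }
}

final class TailReader {
    // Style 1: every tail byte goes through the abstract accessor.
    static int foldTailViaAccessor(Block block, int from) {
        int acc = 0;
        for (int i = from; i < block.size(); i++) {
            acc = 31 * acc + block.getByte(i);
        }
        return acc;
    }

    // Style 2: hoist the base object and offset once, then index directly,
    // similar in shape to the Platform.getByte(o, offset + i) line in the diff.
    static int foldTailDirect(ArrayBlock block, int from) {
        byte[] base = block.data;
        int offset = block.offset;
        int acc = 0;
        for (int i = from; i < block.length; i++) {
            acc = 31 * acc + base[offset + i];
        }
        return acc;
    }
}
```

Both methods return the same value for the same block. Whether the JIT compiles them differently is the question the thread leaves open.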

int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
return fmix(h1, lengthInBytes);
}

public static int hashUTF8String(UTF8String str, int seed) {
- return hashUnsafeBytesBlock(str.getMemoryBlock(), seed);
+ return hashUnsafeBytesBlock(str.getMemoryBlock(), str.numBytes(), seed);
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
@@ -101,7 +107,7 @@ public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) {
int lengthInBytes = Ints.checkedCast(base.size());
assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
- int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed);
+ int h1 = hashBytesByIntBlock(base, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (base.getByte(i) & 0xFF) << shift;
@@ -110,11 +116,10 @@ public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) {
return fmix(h1, lengthInBytes);
}

- private static int hashBytesByIntBlock(MemoryBlock base, int seed) {
- long lengthInBytes = base.size();
+ private static int hashBytesByIntBlock(MemoryBlock base, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
- for (long i = 0; i < lengthInBytes; i += 4) {
+ for (int i = 0; i < lengthInBytes; i += 4) {
int halfWord = base.getInt(i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
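
The two tail-handling strategies visible in `hashUnsafeBytesBlock` and `hashUnsafeBytes2Block` can be sketched on a plain byte array as below. This is a minimal sketch, not Spark's code: `Murmur3Sketch` and its method names are hypothetical, the mixing constants are the standard Murmur3 x86_32 ones, 4-byte words are assumed to be little-endian, and the final `h1 ^= mixK1(k1)` step in the second variant follows the standard Murmur3 tail because that line is not visible in the diff.

```java
final class Murmur3Sketch {
    private static final int C1 = 0xcc9e2d51;
    private static final int C2 = 0x1b873593;

    static int mixK1(int k1) {
        k1 *= C1;
        k1 = Integer.rotateLeft(k1, 15);
        k1 *= C2;
        return k1;
    }

    static int mixH1(int h1, int k1) {
        h1 ^= k1;
        h1 = Integer.rotateLeft(h1, 13);
        return h1 * 5 + 0xe6546b64;
    }

    static int fmix(int h1, int length) {
        h1 ^= length;
        h1 ^= h1 >>> 16;
        h1 *= 0x85ebca6b;
        h1 ^= h1 >>> 13;
        h1 *= 0xc2b2ae35;
        h1 ^= h1 >>> 16;
        return h1;
    }

    // Assumption: words are read little-endian; the real code uses native byte order.
    static int getIntLE(byte[] d, int i) {
        return (d[i] & 0xFF) | (d[i + 1] & 0xFF) << 8
            | (d[i + 2] & 0xFF) << 16 | (d[i + 3] & 0xFF) << 24;
    }

    // Mixes the 4-byte-aligned prefix, one int at a time.
    static int hashAlignedPrefix(byte[] data, int lengthAligned, int seed) {
        int h1 = seed;
        for (int i = 0; i < lengthAligned; i += 4) {
            h1 = mixH1(h1, mixK1(getIntLE(data, i)));
        }
        return h1;
    }

    // Tail style of hashUnsafeBytesBlock: each trailing byte is sign-extended
    // to an int and mixed as if it were a full word.
    static int hashLegacyTail(byte[] data, int seed) {
        int lengthAligned = data.length - data.length % 4;
        int h1 = hashAlignedPrefix(data, lengthAligned, seed);
        for (int i = lengthAligned; i < data.length; i++) {
            h1 = mixH1(h1, mixK1(data[i]));
        }
        return fmix(h1, data.length);
    }

    // Tail style of hashUnsafeBytes2Block: trailing bytes are packed into one
    // partial word, which is then folded in once (assumed standard Murmur3 step).
    static int hashStandardTail(byte[] data, int seed) {
        int lengthAligned = data.length - data.length % 4;
        int h1 = hashAlignedPrefix(data, lengthAligned, seed);
        int k1 = 0;
        for (int i = lengthAligned, shift = 0; i < data.length; i++, shift += 8) {
            k1 ^= (data[i] & 0xFF) << shift;
        }
        h1 ^= mixK1(k1);
        return fmix(h1, data.length);
    }
}
```

When the input length is a multiple of 4 there is no tail, so both variants produce the same hash; they can diverge only on inputs with 1 to 3 trailing bytes.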