Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new constructor to ZstdDictCompress and ZstdDictDecompress #306

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions src/main/java/com/github/luben/zstd/ZstdDictCompress.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,24 @@ public class ZstdDictCompress extends SharedDictBase {
}

private long nativePtr = 0;

private ByteBuffer sharedDict = null;

private int level = Zstd.defaultCompressionLevel();

private native void init(byte[] dict, int dict_offset, int dict_size, int level);

private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level);
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level, int byReference);

private native void free();

/**
* Get the byte buffer that backs this dict, if any, or null if not backed by a byte buffer.
*/
public ByteBuffer getByReferenceBuffer() {
return sharedDict;
}

/**
* Convenience constructor to create a new dictionary for use with fast compress
*
Expand Down Expand Up @@ -59,6 +69,18 @@ public ZstdDictCompress(byte[] dict, int offset, int length, int level) {
* @param level compression level
*/
public ZstdDictCompress(ByteBuffer dict, int level) {
this(dict, level, false);
}

/**
* Create a new dictionary for use with fast compress.
* If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed.
*
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
* @param level compression level
* @param byReference tell the native part to use the byte buffer directly and not copy the data when true.
*/
public ZstdDictCompress(ByteBuffer dict, int level, boolean byReference) {
this.level = level;
int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
Expand All @@ -67,11 +89,14 @@ public ZstdDictCompress(ByteBuffer dict, int level) {
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
initDirect(dict, dict.position(), length, level);
initDirect(dict, dict.position(), length, level, byReference ? 1 : 0);

if (nativePtr == 0L) {
throw new IllegalStateException("ZSTD_createCDict failed");
}
if (byReference) {
sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free.
}
// Ensures that even if ZstdDictCompress is created and published through a race, no thread could observe
// nativePtr == 0.
storeFence();
Expand All @@ -85,7 +110,11 @@ int level() {
@Override
void doClose() {
if (nativePtr != 0) {
free();
if (sharedDict == null) {
free();
} else {
sharedDict = null;
}
nativePtr = 0;
}
}
Expand Down
33 changes: 30 additions & 3 deletions src/main/java/com/github/luben/zstd/ZstdDictDecompress.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@ public class ZstdDictDecompress extends SharedDictBase {

private long nativePtr = 0L;

private ByteBuffer sharedDict = null;

private native void init(byte[] dict, int dict_offset, int dict_size);

private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size);
private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int byReference);

private native void free();

/**
* Get the byte buffer that backs this dict, if any, or null if not backed by a byte buffer.
*/
public ByteBuffer getByReferenceBuffer() {
return sharedDict;
}

/**
* Convenience constructor to create a new dictionary for use with fast decompress
*
Expand Down Expand Up @@ -52,6 +61,17 @@ public ZstdDictDecompress(byte[] dict, int offset, int length) {
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
*/
public ZstdDictDecompress(ByteBuffer dict) {
this(dict, false);
}

/**
* Create a new dictionary for use with fast decompress.
* If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed.
*
* @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer.
* @param byReference tell the native part to use the byte buffer directly and not copy the data when true.
*/
public ZstdDictDecompress(ByteBuffer dict, boolean byReference) {

int length = dict.limit() - dict.position();
if (!dict.isDirect()) {
Expand All @@ -60,11 +80,14 @@ public ZstdDictDecompress(ByteBuffer dict) {
if (length < 0) {
throw new IllegalArgumentException("dict cannot be empty.");
}
initDirect(dict, dict.position(), length);
initDirect(dict, dict.position(), length, byReference ? 1 : 0);

if (nativePtr == 0L) {
throw new IllegalStateException("ZSTD_createDDict failed");
}
if (byReference) {
sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free.
}
// Ensures that even if ZstdDictDecompress is created and published through a race, no thread could observe
// nativePtr == 0.
storeFence();
Expand All @@ -74,7 +97,11 @@ public ZstdDictDecompress(ByteBuffer dict) {
@Override
void doClose() {
if (nativePtr != 0) {
free();
if (sharedDict == null) {
free();
} else {
sharedDict = null;
}
nativePtr = 0;
}
}
Expand Down
22 changes: 16 additions & 6 deletions src/main/native/jni_fast_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_init
/*
* Class: com_github_luben_zstd_ZstdDictCompress
* Method: init
* Signature: (Ljava/nio/ByteBuffer;III)V
* Signature: (Ljava/nio/ByteBuffer;IIII)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_initDirect
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level)
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level, jint byReference)
{
jclass clazz = (*env)->GetObjectClass(env, obj);
compress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
if (NULL == dict) return;
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);
if (NULL == dict_buff) return;
ZSTD_CDict* cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level);
ZSTD_CDict* cdict = NULL;
if (byReference == 0) {
cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level);
} else {
cdict = ZSTD_createCDict_byReference(((char *)dict_buff) + dict_offset, dict_size, level);
}
if (NULL == cdict) return;
(*env)->SetLongField(env, obj, compress_dict, (jlong)(intptr_t) cdict);
}
Expand Down Expand Up @@ -85,17 +90,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_init
/*
* Class: com_github_luben_zstd_ZstdDictDecompress
* Method: initDirect
* Signature: (Ljava/nio/ByteBuffer;II)V
* Signature: (Ljava/nio/ByteBuffer;III)V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_initDirect
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size)
(JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint byReference)
{
jclass clazz = (*env)->GetObjectClass(env, obj);
decompress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
if (NULL == dict) return;
void *dict_buff = (*env)->GetDirectBufferAddress(env, dict);

ZSTD_DDict* ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);
ZSTD_DDict* ddict = NULL;
if (byReference == 0) {
ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size);
} else {
ddict = ZSTD_createDDict_byReference(((char *)dict_buff) + dict_offset, dict_size);
}

if (NULL == ddict) return;
(*env)->SetLongField(env, obj, decompress_dict, (jlong)(intptr_t) ddict);
Expand Down
7 changes: 4 additions & 3 deletions src/test/scala/ZstdDict.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,18 @@ class ZstdDictSpec extends AnyFlatSpec {
assert(input.toSeq == decompressed.toSeq)
}

it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with legacy $legacy" in {
it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with byReference $legacy" in {
val byReference = legacy // Reuse the variance flag here.
val size = input.length
val inBuf = ByteBuffer.allocateDirect(size)
inBuf.put(input)
inBuf.flip()
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level)
val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level, byReference)
val compressed = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt);
Zstd.compress(compressed, inBuf, cdict)
compressed.flip()
cdict.close
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer)
val ddict = new ZstdDictDecompress(dictInDirectByteBuffer, byReference)
val decompressed = ByteBuffer.allocateDirect(size)
Zstd.decompress(decompressed, compressed, ddict)
decompressed.flip()
Expand Down