Skip to content

Commit

Permalink
Initial Java binding through JNI.
Browse files Browse the repository at this point in the history
  • Loading branch information
mingzhao-db committed Jan 1, 2022
1 parent 51bf023 commit 8c4735c
Show file tree
Hide file tree
Showing 8 changed files with 423 additions and 0 deletions.
38 changes: 38 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,41 @@ python_configure(name = "local_config_python")
register_toolchains("@local_config_python//:toolchain")

tf_configure(name = "local_config_tf")

# Java binding support.
#
# Pinned release (v0.3.0) of rules_jni. The JNI rules loaded by
# java/com/google/riegeli/BUILD (cc_jni_library, java_jni_library,
# jni_headers) come from this repository.
http_archive(
name = "fmeum_rules_jni",
sha256 = "8d685e381cb625e11fac330085de2ebc13ad497d30c4e9b09beb212f7c27e8e7",
url = "https://github.com/fmeum/rules_jni/releases/download/v0.3.0/rules_jni-v0.3.0.tar.gz",
)

load("@fmeum_rules_jni//jni:repositories.bzl", "rules_jni_dependencies")

# Presumably registers the repositories that rules_jni itself needs; see
# jni/repositories.bzl in rules_jni to confirm.
rules_jni_dependencies()

load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external")

# JUnit 4.12 jar, used only by tests (testonly_). Several mirror URLs are
# listed for the same artifact; the jar is pinned by jar_sha256.
# NOTE(review): the repository is named "junit_long" rather than "junit" —
# confirm whether this is intentional (e.g. to avoid a name clash).
java_import_external(
name = "junit_long",
jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a",
jar_urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
"https://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
"https://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar",
],
licenses = ["reciprocal"], # Common Public License Version 1.0
testonly_ = True,
deps = ["@org_hamcrest_core"],
)

# Hamcrest core 1.3 jar, declared as a dependency of the JUnit repository
# above and also test-only.
java_import_external(
name = "org_hamcrest_core",
jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9",
jar_urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
"https://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
"https://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
],
licenses = ["notice"], # New BSD License
testonly_ = True,
)
69 changes: 69 additions & 0 deletions java/com/google/riegeli/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
load("@fmeum_rules_jni//jni:defs.bzl", "cc_jni_library", "java_jni_library", "jni_headers")

# Java classes that declare the `native` methods; their implementation is
# provided through JNI by the :riegeli_jni library below.
java_library(
name = "wrapper",
srcs = [
"RecordReader.java",
"RecordWriter.java",
],
)

# Generate the native (JNI) headers from the classes in :wrapper. The C++
# sources include them, e.g. "com_google_riegeli_RecordReader.h".
jni_headers(
name = "header",
lib = ":wrapper",
)

# Native library implementing the `native` methods of :wrapper on top of the
# Riegeli C++ record reader/writer.
# NOTE(review): `visibility` is an empty list here — confirm the intended
# visibility before other packages depend on this target.
cc_jni_library(
name = "riegeli_jni",
srcs = [
"jni_record_reader.cc",
"jni_record_writer.cc",
],
visibility = [
],
deps = [
":header",
"//riegeli/bytes:fd_reader",
"//riegeli/bytes:fd_writer",
"//riegeli/records:record_reader",
"//riegeli/records:record_writer",
],
)

# Java library that bundles :riegeli_jni (native_libs) together with
# Loader.java, which loads the native library at class-initialization time.
# NOTE(review): `visibility` is an empty list here as well — confirm.
java_jni_library(
name = "loader",
srcs = [
"Loader.java",
],
native_libs = [
":riegeli_jni",
],
visibility = [
],
deps = [
":wrapper",
],
)

# Test sources, compiled as a test-only library so that the java_test below
# can pick them up through runtime_deps.
java_library(
name = "tests",
testonly = 1,
srcs = [
"RecordReadWriteTest.java",
],
deps = [
":loader",
":wrapper",
"@junit_long",
"@org_hamcrest_core",
],
)

# Runs RecordReadWriteTest. The target has no srcs of its own; the test class
# comes from :tests.
java_test(
name = "RecordReadWriteTest",
runtime_deps = [
":tests",
],
)
19 changes: 19 additions & 0 deletions java/com/google/riegeli/Loader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.google.riegeli;

import com.github.fmeum.rules_jni.RulesJni;

/**
 * Factory for {@link RecordReader} and {@link RecordWriter} instances.
 *
 * <p>Initializing this class loads the {@code riegeli_jni} native library, so
 * obtaining readers and writers through these factory methods guarantees the
 * native code is available before any {@code native} method is invoked.
 */
public class Loader {
  // Open question kept from the original author: it is unclear whether the
  // extra indirection of a dedicated class for loading the JNI library is
  // worth it.
  static {
    final String nativeLibrary = "riegeli_jni";
    RulesJni.loadLibrary(nativeLibrary, RecordReader.class);
    RulesJni.loadLibrary(nativeLibrary, RecordWriter.class);
  }

  /** Returns a new, not yet opened {@link RecordWriter}. */
  public static final RecordWriter newWriter() {
    final RecordWriter writer = new RecordWriter();
    return writer;
  }

  /** Returns a new, not yet opened {@link RecordReader}. */
  public static RecordReader newReader() {
    final RecordReader reader = new RecordReader();
    return reader;
  }
}
40 changes: 40 additions & 0 deletions java/com/google/riegeli/RecordReadWriteTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.google.riegeli;

import org.junit.Before;
import org.junit.Test;

import java.io.IOException;

import static org.junit.Assert.assertEquals;

public class RecordReadWriteTest {

@Before
public void setUp() {
}

@Test
public void writeWriteString() throws IOException {
// TODO: create a random file on TEST_TEMP directory.
final String filename = "/tmp/test.rg";
RecordWriter writer = Loader.newWriter();
writer.open(filename, "default");
final int kNumRecords = 4096;
for (int i = 0; i < kNumRecords; i++) {
final String s = "a".repeat(i + 1);
writer.writeRecord(s);
}
writer.close();

RecordReader reader = Loader.newReader();
reader.open(filename);
for (int i = 0; i < kNumRecords; i++) {
byte[] record = reader.readRecord();
final String s = "a".repeat(i + 1);
assertEquals(new String(record), s);
}
byte[] record = reader.readRecord();
assertEquals(null, record);
reader.close();
}
}
19 changes: 19 additions & 0 deletions java/com/google/riegeli/RecordReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.google.riegeli;

import java.io.IOException;

/**
 * JNI wrapper for the Riegeli record reader.
 *
 * <p>The native library must be loaded before any {@code native} method is
 * called. Instances own a native object between {@link #open} and
 * {@link #close}; implementing {@link AutoCloseable} allows use in
 * try-with-resources.
 */
public class RecordReader implements AutoCloseable {

  /** Reader options. Nothing is supported for now. */
  public final static class Options {
    // Nothing is supported for now.
  }

  // Address of the native reader object, or 0 when no reader is open.
  // Read and written by the native code via JNI field lookup by name, so the
  // field name and type must stay in sync with jni_record_reader.cc.
  private long recordReaderPtr;

  /**
   * Opens {@code filename} for reading and attaches the native reader to this
   * object.
   *
   * @param filename path of the Riegeli file to read
   * @throws IOException declared for open failures reported by the native code
   */
  public native void open(String filename) throws IOException;

  /**
   * Reads the next record.
   *
   * @return the record bytes, or {@code null} when no reader is open or the
   *     native reader could not produce another record (end of file and read
   *     errors are not distinguished by this method)
   */
  public native byte[] readRecord();

  /**
   * Closes and releases the native reader. Does nothing if no reader is open.
   *
   * @throws IOException declared for close failures reported by the native code
   */
  @Override
  public native void close() throws IOException;
}
50 changes: 50 additions & 0 deletions java/com/google/riegeli/RecordWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.google.riegeli;

import java.io.IOException;

/**
 * JNI wrapper for the Riegeli record writer.
 *
 * <p>The native library must be loaded before any {@code native} method is
 * called. Implementing {@link AutoCloseable} allows use in try-with-resources.
 */
public class RecordWriter implements AutoCloseable {

  /** Writer options. Nothing is supported for now. */
  public final static class Options {
    // Nothing is supported for now.
  }

  // Pointer to the C++ object. Presumably accessed by the native code by field
  // name (as the reader's native code does), so keep name and type unchanged.
  private long recordWriterPtr;

  // Options could be:
  // ```
  // options ::= option? ("," option?)*
  // option ::=
  //   "default" |
  //   "transpose" (":" ("true" | "false"))? |
  //   "uncompressed" |
  //   "brotli" (":" brotli_level)? |
  //   "zstd" (":" zstd_level)? |
  //   "snappy" |
  //   "window_log" ":" window_log |
  //   "chunk_size" ":" chunk_size |
  //   "bucket_fraction" ":" bucket_fraction |
  //   "pad_to_block_boundary" (":" ("true" | "false"))? |
  //   "parallelism" ":" parallelism
  // brotli_level ::= integer in the range [0..11] (default 6)
  // zstd_level ::= integer in the range [-131072..22] (default 3)
  // window_log ::= "auto" or integer in the range [10..31]
  // chunk_size ::= "auto" or positive integer expressed as real with
  //   optional suffix [BkKMGTPE]
  // bucket_fraction ::= real in the range [0..1]
  // parallelism ::= non-negative integer
  // ```
  /**
   * Opens {@code filename} for writing.
   *
   * @param filename path of the file to write
   * @param options writer options in the textual format described above
   * @throws IOException declared for failures reported by the native code
   */
  public native void open(String filename, String options) throws IOException;

  /**
   * Writes {@code record} encoded as UTF-8.
   *
   * <p>The charset is fixed so that the bytes on disk do not depend on the
   * platform default encoding of the writing JVM.
   *
   * @param record the text to write as a single record
   * @throws IOException if the underlying byte write fails
   */
  public void writeRecord(String record) throws IOException {
    writeRecord(record.getBytes(java.nio.charset.StandardCharsets.UTF_8));
  }

  /**
   * Writes {@code record} as a single record.
   *
   * @throws IOException declared for failures reported by the native code
   */
  public native void writeRecord(byte[] record) throws IOException;

  // Flush the data into disk, more `writeRecord` can be called.
  public native void flush() throws IOException;

  /**
   * Closes the writer.
   *
   * @throws IOException declared for failures reported by the native code
   */
  @Override
  public native void close() throws IOException;
}
82 changes: 82 additions & 0 deletions java/com/google/riegeli/jni_record_reader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "com_google_riegeli_RecordReader.h"

#include "riegeli/bytes/fd_reader.h"
#include "riegeli/records/record_reader.h"

#ifdef __cplusplus
extern "C" {
#endif
// Concrete native reader type whose address is stored in the Java object.
using ReaderType = riegeli::RecordReader<riegeli::FdReader<>>;

/*
 * Class: com_google_riegeli_RecordReader
 * Method: open
 * Signature: (Ljava/lang/String;)V
 *
 * Opens `filename` read-only and stores the address of the new native reader
 * in the Java field `recordReaderPtr`. Throws java.io.IOException if the file
 * cannot be opened and java.lang.NullPointerException if `filename` is null.
 * A reader previously attached to the same Java object is released.
 */
JNIEXPORT void JNICALL Java_com_google_riegeli_RecordReader_open(JNIEnv* env, jobject reader, jstring filename) {
  if (filename == nullptr) {
    jclass npe_class = env->FindClass("java/lang/NullPointerException");
    if (npe_class != nullptr) {
      env->ThrowNew(npe_class, "filename must not be null");
    }
    return;
  }

  /* Get the Field ID of the instance variable "recordReaderPtr". */
  jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J");
  if (fid == nullptr) {
    return;  // GetFieldID left a Java exception pending.
  }

  /* Obtain a C-copy of the Java string. */
  const char* fname = env->GetStringUTFChars(filename, nullptr);
  if (fname == nullptr) {
    return;  // GetStringUTFChars left a Java exception pending.
  }

  /* Create the reader. */
  riegeli::RecordReaderBase::Options record_reader_options;
  auto* record_reader = new ReaderType(
      std::forward_as_tuple(fname, O_RDONLY),
      record_reader_options);

  env->ReleaseStringUTFChars(filename, fname);

  // Report open failures to Java instead of handing back a broken reader.
  const auto status = record_reader->status();
  if (!status.ok()) {
    const std::string message = status.ToString();
    delete record_reader;
    jclass io_exception_class = env->FindClass("java/io/IOException");
    if (io_exception_class != nullptr) {
      env->ThrowNew(io_exception_class, message.c_str());
    }
    return;
  }

  // Release a reader left over from an earlier open() so it is not leaked.
  delete reinterpret_cast<ReaderType*>(env->GetLongField(reader, fid));

  // Save the pointer as member.
  env->SetLongField(reader, fid, reinterpret_cast<jlong>(record_reader));
}

namespace {
// Returns the native reader whose address is stored in the Java field
// `recordReaderPtr` of `reader`; the result is null when the field holds 0.
ReaderType* getRecordReader(JNIEnv* env, jobject reader) {
  jclass reader_class = env->GetObjectClass(reader);
  jfieldID ptr_field = env->GetFieldID(reader_class, "recordReaderPtr", "J");
  const jlong address = env->GetLongField(reader, ptr_field);
  return reinterpret_cast<ReaderType*>(address);
}
}  // namespace

/*
 * Class: com_google_riegeli_RecordReader
 * Method: readRecord
 * Signature: ()[B
 *
 * Returns the next record as a new Java byte array. Returns null when no
 * reader is attached, when ReadRecord() reports that no record was read, or
 * when the byte array cannot be allocated (a Java exception is then pending).
 */
JNIEXPORT jbyteArray JNICALL Java_com_google_riegeli_RecordReader_readRecord(JNIEnv* env, jobject obj) {
  auto* record_reader = getRecordReader(env, obj);
  if (record_reader == nullptr) {
    return nullptr;
  }

  std::string record;
  if (!record_reader->ReadRecord(record)) {
    return nullptr;
  }

  const jsize length = static_cast<jsize>(record.size());
  jbyteArray ret = env->NewByteArray(length);
  if (ret == nullptr) {
    // Allocation failed; NewByteArray left a Java exception pending and the
    // null array must not be passed to SetByteArrayRegion.
    return nullptr;
  }
  env->SetByteArrayRegion(ret, 0, length, reinterpret_cast<const jbyte*>(record.data()));
  return ret;
}

/*
 * Class: com_google_riegeli_RecordReader
 * Method: close
 * Signature: ()V
 *
 * Closes and deletes the native reader attached to `reader` and resets the
 * Java field `recordReaderPtr` to 0. Does nothing when no reader is attached.
 * Throws java.io.IOException if Close() reports a failure.
 */
JNIEXPORT void JNICALL Java_com_google_riegeli_RecordReader_close(JNIEnv* env, jobject reader) {
  auto* record_reader = getRecordReader(env, reader);
  if (record_reader == nullptr) {
    return;
  }

  // Capture the failure description before the reader is destroyed.
  const bool closed_ok = record_reader->Close();
  std::string failure;
  if (!closed_ok) {
    failure = record_reader->status().ToString();
  }
  delete record_reader;

  // Clear the field before raising anything: most JNI calls must not be made
  // while a Java exception is pending.
  jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J");
  env->SetLongField(reader, fid, 0L);

  if (!closed_ok) {
    jclass io_exception_class = env->FindClass("java/io/IOException");
    if (io_exception_class != nullptr) {
      env->ThrowNew(io_exception_class, failure.c_str());
    }
  }
}

#ifdef __cplusplus
}
#endif

Loading

0 comments on commit 8c4735c

Please sign in to comment.