Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Initial Java binding through JNI. #21

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,41 @@ python_configure(name = "local_config_python")
register_toolchains("@local_config_python//:toolchain")

tf_configure(name = "local_config_tf")

# Java binding support.
http_archive(
name = "fmeum_rules_jni",
sha256 = "8d685e381cb625e11fac330085de2ebc13ad497d30c4e9b09beb212f7c27e8e7",
url = "https://github.com/fmeum/rules_jni/releases/download/v0.3.0/rules_jni-v0.3.0.tar.gz",
)

load("@fmeum_rules_jni//jni:repositories.bzl", "rules_jni_dependencies")

rules_jni_dependencies()

load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external")

java_import_external(
name = "junit_long",
jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a",
jar_urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
"https://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
"https://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar",
],
licenses = ["reciprocal"], # Common Public License Version 1.0
testonly_ = True,
deps = ["@org_hamcrest_core"],
)

java_import_external(
name = "org_hamcrest_core",
jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9",
jar_urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
"https://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
"https://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
],
licenses = ["notice"], # New BSD License
testonly_ = True,
)
69 changes: 69 additions & 0 deletions java/com/google/riegeli/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
load("@fmeum_rules_jni//jni:defs.bzl", "cc_jni_library", "java_jni_library", "jni_headers")

# Java interface for that will be implemented using JNI later.
java_library(
name = "wrapper",
srcs = [
"RecordReader.java",
"RecordWriter.java",
],
)

# Generate the native header.
jni_headers(
name = "header",
lib = ":wrapper",
)

cc_jni_library(
name = "riegeli_jni",
srcs = [
"jni_record_reader.cc",
"jni_record_writer.cc",
],
visibility = [
],
deps = [
":header",
"//riegeli/bytes:fd_reader",
"//riegeli/bytes:fd_writer",
"//riegeli/records:record_reader",
"//riegeli/records:record_writer",
],
)

java_jni_library(
name = "loader",
srcs = [
"Loader.java",
],
native_libs = [
":riegeli_jni",
],
visibility = [
],
deps = [
":wrapper",
],
)

java_library(
name = "tests",
testonly = 1,
srcs = [
"RecordReadWriteTest.java",
],
deps = [
":loader",
":wrapper",
"@junit_long",
"@org_hamcrest_core",
],
)

java_test(
name = "RecordReadWriteTest",
runtime_deps = [
":tests",
],
)
19 changes: 19 additions & 0 deletions java/com/google/riegeli/Loader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.google.riegeli;

import com.github.fmeum.rules_jni.RulesJni;

public class Loader {
// Not sure whether it's worth the redirection to put JNI native lib loading logic in a single class.
static {
RulesJni.loadLibrary("riegeli_jni", RecordReader.class);
RulesJni.loadLibrary("riegeli_jni", RecordWriter.class);
}

public final static RecordWriter newWriter() {
return new RecordWriter();
}

public static RecordReader newReader() {
return new RecordReader();
}
}
46 changes: 46 additions & 0 deletions java/com/google/riegeli/RecordReadWriteTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.google.riegeli;

import org.junit.Before;
import org.junit.Test;

import java.io.IOException;

import static org.junit.Assert.assertEquals;

public class RecordReadWriteTest {

@Before
public void setUp() {
}

private String createTestString(int length) {
// String.repeat is only available from Java 11, so create a
// helper method instead.
return String.join("", java.util.Collections.nCopies(length, "a"));
}

@Test
public void writeWriteString() throws IOException {
// TODO: create a random file on TEST_TEMP directory.
final String filename = "/tmp/test.rg";
RecordWriter writer = Loader.newWriter();
writer.open(filename, "default");
final int kNumRecords = 4096;
for (int i = 0; i < kNumRecords; i++) {
final String s = createTestString(i+1);
writer.writeRecord(s);
}
writer.close();

RecordReader reader = Loader.newReader();
reader.open(filename);
for (int i = 0; i < kNumRecords; i++) {
byte[] record = reader.readRecord();
final String s = createTestString(i+1);
assertEquals(new String(record), s);
}
byte[] record = reader.readRecord();
assertEquals(null, record);
reader.close();
}
}
19 changes: 19 additions & 0 deletions java/com/google/riegeli/RecordReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.google.riegeli;

import java.io.IOException;

// JNI wrapper for riegeli record reader.
public class RecordReader {

public final static class Options {
// Nothing is supported for now.
}

private long recordReaderPtr;

public native void open(String filename) throws IOException;

public native byte[] readRecord();

public native void close() throws IOException;
}
50 changes: 50 additions & 0 deletions java/com/google/riegeli/RecordWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.google.riegeli;

import java.io.IOException;

// JNI wrapper for riegeli record writer.
public class RecordWriter {

public final static class Options {
// Nothing is supported for now.
}

// Pointer to the C++ object.
private long recordWriterPtr;

// Options could be:
// ```
// options ::= option? ("," option?)*
// option ::=
// "default" |
// "transpose" (":" ("true" | "false"))? |
// "uncompressed" |
// "brotli" (":" brotli_level)? |
// "zstd" (":" zstd_level)? |
// "snappy" |
// "window_log" ":" window_log |
// "chunk_size" ":" chunk_size |
// "bucket_fraction" ":" bucket_fraction |
// "pad_to_block_boundary" (":" ("true" | "false"))? |
// "parallelism" ":" parallelism
// brotli_level ::= integer in the range [0..11] (default 6)
// zstd_level ::= integer in the range [-131072..22] (default 3)
// window_log ::= "auto" or integer in the range [10..31]
// chunk_size ::= "auto" or positive integer expressed as real with
// optional suffix [BkKMGTPE]
// bucket_fraction ::= real in the range [0..1]
// parallelism ::= non-negative integer
// ```
public native void open(String filename, String options) throws IOException;

public void writeRecord(String record) throws IOException {
writeRecord(record.getBytes());
}

public native void writeRecord(byte[] record) throws IOException;

// Flush the data into disk, more `writeRecord` can be called.
public native void flush() throws IOException;

public native void close() throws IOException;
}
82 changes: 82 additions & 0 deletions java/com/google/riegeli/jni_record_reader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "com_google_riegeli_RecordReader.h"

#include "riegeli/bytes/fd_reader.h"
#include "riegeli/records/record_reader.h"

#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: com_google_riegeli_RecordReader
* Method: open
* Signature: (Ljava/lang/String;)V
*/
using ReaderType = riegeli::RecordReader<riegeli::FdReader<>>;

JNIEXPORT void JNICALL Java_com_google_riegeli_RecordReader_open(JNIEnv* env, jobject reader, jstring filename) {
/* Obtain a C-copy of the Java string */
const char* fname = env->GetStringUTFChars(filename, nullptr);

/* Create the recorder */
riegeli::RecordReaderBase::Options record_reader_options;
auto* record_reader = new ReaderType(
std::forward_as_tuple(fname, O_RDONLY),
record_reader_options);

env->ReleaseStringUTFChars(filename, fname);

/* Get the Field ID of the instance variables "recordReaderPtr" */
jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J");

// Save the pointer as member.
env->SetLongField(reader, fid, reinterpret_cast<jlong>(record_reader));
}

namespace {
ReaderType* getRecordReader(JNIEnv* env, jobject reader) {
jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J");
jlong ptr = env->GetLongField(reader, fid);
return reinterpret_cast<ReaderType*>(ptr);
}
}

/*
* Class: com_google_riegeli_RecordReader
* Method: readRecord
* Signature: ()[B
*/
JNIEXPORT jbyteArray JNICALL Java_com_google_riegeli_RecordReader_readRecord(JNIEnv* env, jobject obj) {
auto* record_reader = getRecordReader(env, obj);
if (!record_reader) {
return nullptr;
}
std::string record;
if (record_reader->ReadRecord(record)) {
jbyteArray ret = env->NewByteArray(record.size());
env->SetByteArrayRegion(ret, 0, record.size(), reinterpret_cast<const jbyte*>(record.data()));
return ret;
} else {
return nullptr;
}
}

/*
* Class: com_google_riegeli_RecordReader
* Method: close
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_com_google_riegeli_RecordReader_close(JNIEnv* env, jobject reader) {
auto* record_reader = getRecordReader(env, reader);
if (record_reader) {
record_reader->Close();
delete record_reader;

jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J");
env->SetLongField(reader, fid, 0L);
}
}

#ifdef __cplusplus
}
#endif

Loading