Skip to content

Commit

Permalink
Add a memory-mapped RandomAccessReader using the MemorySegment API
Browse files Browse the repository at this point in the history
I'd prefer this to be in the `jvector-twenty` module, but the `--enable-preview` flag is only
allowed for the Java release version used to compile the code. When building with Java 22,
`--enable-preview` is not allowed on the `twenty` module because it builds for Java 20.
  • Loading branch information
mdogan committed Apr 16, 2024
1 parent 70a6df8 commit edf50bb
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 0 deletions.
8 changes: 8 additions & 0 deletions jvector-native/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
</nonFilteredFileExtensions>
</configuration>
</plugin>
<plugin>
<!-- Surefire is the Maven plugin that runs unit tests. skip is set to false explicitly so that
     tests in this module are executed (presumably overriding a skip configured elsewhere,
     e.g. in a parent POM; TODO confirm). -->
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.1.2</version>
<configuration>
<skip>false</skip>
</configuration>
</plugin>
</plugins>
</build>
<profiles>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.disk;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.foreign.ValueLayout.OfFloat;
import java.lang.foreign.ValueLayout.OfInt;
import java.lang.foreign.ValueLayout.OfLong;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/**
 * {@link MemorySegment} based implementation of RandomAccessReader.
 * MemorySegmentReader doesn't have 2GB file size limitation of {@link SimpleMappedReader}.
 *
 * <p>The whole file is mapped read-only into a single segment and all multi-byte values are
 * decoded as big-endian (the byte order produced by {@link java.io.DataOutput}).
 *
 * <p>Each instance keeps its own read position in a plain field, so a single instance must not
 * be used by several threads at once. Use {@link #duplicate()} to obtain an independent cursor
 * over the same mapping.
 *
 * <p>Ownership: only the reader created through {@link #MemorySegmentReader(Path)} owns the
 * mapping. Closing that reader unmaps the file and invalidates every duplicate; closing a
 * duplicate releases nothing and leaves the mapping usable by all other readers.
 */
public class MemorySegmentReader implements RandomAccessReader {

    // Unaligned layouts: values may start at any byte offset in the file.
    private static final OfInt INT_LAYOUT = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
    private static final OfFloat FLOAT_LAYOUT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
    private static final OfLong LONG_LAYOUT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);

    private final Arena arena;
    private final MemorySegment memory;
    /** True only for the reader that opened the mapping; duplicates must never close the arena. */
    private final boolean ownsArena;
    private long position = 0;

    /**
     * Maps the whole file at {@code path} read-only and positions the reader at offset 0.
     *
     * @param path file to map
     * @throws IOException if the file cannot be opened or mapped
     */
    public MemorySegmentReader(Path path) throws IOException {
        arena = Arena.ofShared();
        ownsArena = true;
        // The channel can be closed right after mapping; the mapping's lifetime is tied to the arena.
        try (var ch = FileChannel.open(path, StandardOpenOption.READ)) {
            memory = ch.map(MapMode.READ_ONLY, 0L, ch.size(), arena);
        } catch (Exception e) {
            // Do not leak the arena when opening or mapping fails.
            arena.close();
            throw e;
        }
    }

    /** Creates a non-owning view over an existing mapping, starting at offset 0. */
    private MemorySegmentReader(Arena arena, MemorySegment memory) {
        this.arena = arena;
        this.memory = memory;
        this.ownsArena = false;
    }

    /**
     * Moves the read position to {@code offset} (in bytes from the start of the file).
     * The offset is not validated here; an out-of-range position surfaces on the next read.
     */
    @Override
    public void seek(long offset) {
        this.position = offset;
    }

    /** Returns the current read position in bytes from the start of the file. */
    @Override
    public long getPosition() {
        return position;
    }

    /** Fills {@code buffer} with {@code buffer.length} big-endian floats and advances the position. */
    @Override
    public void readFully(float[] buffer) {
        MemorySegment.copy(memory, FLOAT_LAYOUT, position, buffer, 0, buffer.length);
        position += buffer.length * (long) Float.BYTES;
    }

    /** Fills {@code b} with {@code b.length} bytes and advances the position. */
    @Override
    public void readFully(byte[] b) {
        MemorySegment.copy(memory, ValueLayout.JAVA_BYTE, position, b, 0, b.length);
        position += b.length;
    }

    /**
     * Copies {@code buffer.remaining()} bytes into {@code buffer} and advances the position.
     * On return the buffer's position equals its limit.
     */
    @Override
    public void readFully(ByteBuffer buffer) {
        var remaining = buffer.remaining();
        // The slice is at most Integer.MAX_VALUE bytes, so it can be viewed as a ByteBuffer
        // even when the mapped file itself is larger than 2GB.
        var slice = memory.asSlice(position, remaining).asByteBuffer();
        buffer.put(slice);
        position += remaining;
    }

    /** Fills {@code vector} with {@code vector.length} big-endian longs and advances the position. */
    @Override
    public void readFully(long[] vector) {
        MemorySegment.copy(memory, LONG_LAYOUT, position, vector, 0, vector.length);
        position += vector.length * (long) Long.BYTES;
    }

    /** Reads one big-endian int and advances the position by 4 bytes. */
    @Override
    public int readInt() {
        var k = memory.get(INT_LAYOUT, position);
        position += Integer.BYTES;
        return k;
    }

    /** Reads one big-endian float and advances the position by 4 bytes. */
    @Override
    public float readFloat() {
        var f = memory.get(FLOAT_LAYOUT, position);
        position += Float.BYTES;
        return f;
    }

    /**
     * Reads {@code count} big-endian ints into {@code ints} starting at index {@code offset}
     * and advances the position.
     */
    @Override
    public void read(int[] ints, int offset, int count) {
        MemorySegment.copy(memory, INT_LAYOUT, position, ints, offset, count);
        position += count * (long) Integer.BYTES;
    }

    /**
     * Reads {@code count} big-endian floats into {@code floats} starting at index {@code offset}
     * and advances the position.
     */
    @Override
    public void read(float[] floats, int offset, int count) {
        MemorySegment.copy(memory, FLOAT_LAYOUT, position, floats, offset, count);
        position += count * (long) Float.BYTES;
    }

    /**
     * Loads the contents of the mapped segment into physical memory.
     * This is a best-effort mechanism.
     */
    public void loadMemory() {
        memory.load();
    }

    /**
     * Releases the mapping if this reader owns it.
     *
     * <p>For the reader created from a {@link Path} this closes the arena, after which this
     * reader and all of its duplicates fail with {@link IllegalStateException} on access.
     * For a reader obtained from {@link #duplicate()} this is a no-op, so that closing one
     * duplicate cannot invalidate the mapping for the other readers sharing it.
     */
    @Override
    public void close() {
        if (ownsArena) {
            arena.close();
        }
    }

    /**
     * Returns a new reader over the same mapping with its own position, starting at offset 0.
     * The returned reader does not own the mapping; it stays valid until the reader that
     * opened the file is closed.
     */
    public MemorySegmentReader duplicate() {
        return new MemorySegmentReader(arena, memory);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.github.jbellis.jvector.disk;

import java.io.IOException;
import java.nio.file.Path;

/**
 * {@link ReaderSupplier} that maps a file once with a {@link MemorySegmentReader} and hands out
 * duplicates of that reader, each with its own read position over the same mapping.
 */
public class MemorySegmentReaderSupplier implements ReaderSupplier {
    /** Reader opened at construction time; every supplied reader is a duplicate of it. */
    private final MemorySegmentReader source;

    /**
     * Opens and maps the file at {@code path}.
     *
     * @param path file to read from
     * @throws IOException if the underlying reader cannot be created
     */
    public MemorySegmentReaderSupplier(Path path) throws IOException {
        this.source = new MemorySegmentReader(path);
    }

    /** Returns a fresh duplicate of the underlying reader. */
    @Override
    public RandomAccessReader get() {
        MemorySegmentReader copy = this.source.duplicate();
        return copy;
    }

    /** Closes the reader that was opened by the constructor. */
    @Override
    public void close() {
        this.source.close();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.disk;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class MemorySegmentReaderTest extends RandomizedTest {

private Path tempFile;

@Before
public void setup() throws IOException {
tempFile = Files.createTempFile(getClass().getSimpleName(), ".data");

try (var out = new DataOutputStream(new FileOutputStream(tempFile.toFile()))) {
out.write(new byte[] {1, 2, 3, 4, 5, 6, 7});
for (int i = 0; i < 5; i++) {
out.writeInt((i + 1) * 19);
}
for (int i = 0; i < 5; i++) {
out.writeLong((i + 1) * 19L);
}
for (int i = 0; i < 5; i++) {
out.writeFloat((i + 1) * 19);
}
}
}

@After
public void tearDown() throws IOException {
Files.deleteIfExists(tempFile);
}

@Test
public void testReader() throws Exception {
try (var r = new MemorySegmentReader(tempFile)) {
verifyReader(r);

// read 2nd time from beginning
verifyReader(r);
}
}

@Test
public void testReaderDuplicate() throws Exception {
try (var r = new MemorySegmentReader(tempFile)) {
for (int i = 0; i < 3; i++) {
var r2 = r.duplicate();
verifyReader(r2);
}
}
}

@Test
public void testReaderClose() throws Exception {
var r = new MemorySegmentReader(tempFile);
var r2 = r.duplicate();

r.close();

try {
r.readInt();
fail("Should have thrown an exception");
} catch (IllegalStateException _) {
}

try {
r2.readInt();
fail("Should have thrown an exception");
} catch (IllegalStateException _) {
}
}

private void verifyReader(MemorySegmentReader r) {
r.seek(0);
var bytes = new byte[7];
r.readFully(bytes);
for (int i = 0; i < bytes.length; i++) {
assertEquals(i + 1, bytes[i]);
}

r.seek(0);
var buff = ByteBuffer.allocate(6);
r.readFully(buff);
for (int i = 0; i < buff.remaining(); i++) {
assertEquals(i + 1, buff.get(i));
}

r.seek(7);
assertEquals(19, r.readInt());

r.seek(7);
var ints = new int[5];
r.read(ints, 0, ints.length);
for (int i = 0; i < ints.length; i++) {
var k = ints[i];
assertEquals((i + 1) * 19, k);
}

r.seek(7 + (4 * 5));
var longs = new long[5];
r.readFully(longs);
for (int i = 0; i < longs.length; i++) {
var l = longs[i];
assertEquals((i + 1) * 19, l);
}

r.seek(7 + (4 * 5) + (8 * 5));
assertEquals(19, r.readFloat(), 0.01);

r.seek(7 + (4 * 5) + (8 * 5));
var floats = new float[5];
r.readFully(floats);
for (int i = 0; i < floats.length; i++) {
var f = floats[i];
assertEquals((i + 1) * 19f, f, 0.01);
}
}
}

0 comments on commit edf50bb

Please sign in to comment.