Skip to content

Commit

Permalink
hold a global plasma client instance (apache#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
yeyuqiang authored Jan 4, 2021
1 parent 4a3a04e commit 808d96d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 45 deletions.
26 changes: 26 additions & 0 deletions core/src/main/java/org/apache/spark/io/pmem/MyPlasmaClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import org.apache.arrow.plasma.PlasmaClient;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkEnv;
import org.apache.spark.internal.config.package$;

/**
* Upstream Plasma Client Wrapper.
Expand Down Expand Up @@ -63,3 +65,27 @@ public byte[] toBytes() {
return objectId.getBytes();
}
}

/**
 * Holds a single, lazily created {@link MyPlasmaClient} that is shared by all callers
 * of {@link #get()}.
 *
 * Both static methods are synchronized so that concurrent callers cannot create more than
 * one client and so that {@link #close()} cannot race with {@link #get()}.
 * NOTE(review): this only guards the holder's own state; whether the returned client may be
 * used from several threads at once is not visible here -- confirm.
 */
class MyPlasmaClientHolder {

  /** Socket used when no {@link SparkEnv} is available to read the configuration from. */
  private static final String DEFAULT_STORE_SERVER_SOCKET = "/tmp/plasma";

  private static MyPlasmaClient client;

  /**
   * Returns the shared client, creating it on first use.
   *
   * The store socket is read from the {@code PLASMA_SERVER_SOCKET} config entry of the
   * current {@link SparkEnv}; if there is no current {@link SparkEnv} the default socket
   * is used instead.
   *
   * @return the shared plasma client, never {@code null}
   */
  public static synchronized MyPlasmaClient get() {
    if (client == null) {
      // Read the environment once so the null check and the use see the same value.
      SparkEnv env = SparkEnv.get();
      String storeSocketName = env == null
          ? DEFAULT_STORE_SERVER_SOCKET
          : env.conf().get(package$.MODULE$.PLASMA_SERVER_SOCKET());
      client = new MyPlasmaClient(storeSocketName);
    }
    return client;
  }

  /**
   * Releases the shared client, if one exists, so that the next {@link #get()} creates a
   * fresh one. Calling this when no client has been created is a no-op.
   */
  public static synchronized void close() {
    if (client == null) {
      return;
    }
    try {
      client.finalize();
    } finally {
      // Drop the reference even if the release fails, so a stale client is never reused.
      client = null;
    }
  }

}
20 changes: 4 additions & 16 deletions core/src/main/java/org/apache/spark/io/pmem/PlasmaInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,25 @@ public class PlasmaInputStream extends InputStream {
* Make sure the given buffer size for input stream is equal to the output stream's
*
* @param parentObjectId parent object id
* @param storeSocketName plasma object store socket
* @param bufferSize buffer size
*/
public PlasmaInputStream(String parentObjectId, String storeSocketName, int bufferSize) {
public PlasmaInputStream(String parentObjectId, int bufferSize) {
this.bufferSize = bufferSize;
this.buffer = ByteBuffer.allocate(bufferSize);
buffer.flip();

this.parentObjectId = parentObjectId;
this.client = new MyPlasmaClient(storeSocketName);
this.client = MyPlasmaClientHolder.get();
this.currChildObjectNumber = 0;
}

/**
* Use {@code DEFAULT_BUFFER_SIZE} as buffer size.
*
* @param parentObjectId
* @param storeSocketName
*/
public PlasmaInputStream(String parentObjectId, String storeSocketName) {
this(parentObjectId, storeSocketName, DEFAULT_BUFFER_SIZE);
public PlasmaInputStream(String parentObjectId) {
this(parentObjectId, DEFAULT_BUFFER_SIZE);
}

private boolean refill() {
Expand Down Expand Up @@ -127,14 +125,4 @@ public long skip(long n) {
return n - remaining;
}

@Override
public void close() {
client.finalize();
}

@Override
protected void finalize() {
close();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,25 @@ public class PlasmaOutputStream extends OutputStream {
* Initialize with a customized buffer size.
*
* @param parentObjectId parent object id
* @param storeSocketName plasma object store socket
* @param bufferSize buffer size
*/
public PlasmaOutputStream(String parentObjectId, String storeSocketName, int bufferSize) {
public PlasmaOutputStream(String parentObjectId, int bufferSize) {
if (bufferSize < 0) {
throw new IllegalArgumentException("buffer size can not be a negative number");
}
this.buffer = ByteBuffer.allocate(bufferSize);
this.parentObjectId = parentObjectId;
this.client = new MyPlasmaClient(storeSocketName);
this.client = MyPlasmaClientHolder.get();
this.currChildObjectNumber = 0;
}

/**
* Use {@code DEFAULT_BUFFER_SIZE} as buffer size.
*
* @param parentObjectId
* @param storeSocketName
*/
public PlasmaOutputStream(String parentObjectId, String storeSocketName) {
this(parentObjectId, storeSocketName, DEFAULT_BUFFER_SIZE);
public PlasmaOutputStream(String parentObjectId) {
this(parentObjectId, DEFAULT_BUFFER_SIZE);
}

@Override
Expand Down Expand Up @@ -114,8 +112,4 @@ private byte[] shrinkLastObjBuffer() {
return lastObjBytes;
}

@Override
public void close() {
client.finalize();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1938,4 +1938,9 @@ package object config {
.version("3.0.1")
.booleanConf
.createWithDefault(false)

private[spark] val PLASMA_SERVER_SOCKET = ConfigBuilder("spark.io.plasma.server.socket")
  .doc("Socket of the plasma object store server that the plasma client connects to.")
  .version("3.1.0")
  .stringConf
  .createWithDefault("/tmp/plasma")
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.apache.spark.io.pmem;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -11,14 +13,15 @@
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.Assume.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Tests functionality of {@link PlasmaInputStream} and {@link PlasmaOutputStream}
Expand Down Expand Up @@ -51,31 +54,29 @@ public void setUp() {
} catch (InterruptedException ex2) {
ex2.printStackTrace();
}
mockSparkEnv();
}

@Test(expected = NullPointerException.class)
public void testWithNullData() throws IOException {
PlasmaOutputStream pos = new PlasmaOutputStream(
"testWithEmptyData-dummy-id", plasmaStoreSocket);
PlasmaOutputStream pos = new PlasmaOutputStream("testWithEmptyData-dummy-id");
pos.write(null);
}

@Test
public void testSingleWriteRead() {
String blockId = "block_id_" + random.nextInt(10000000);
byte[] bytesWrite = prepareByteBlockToWrite(1);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId, plasmaStoreSocket);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId);
for (byte b : bytesWrite) {
pos.write(b);
}
pos.close();

byte[] bytesRead = new byte[bytesWrite.length];
PlasmaInputStream pis = new PlasmaInputStream(blockId, plasmaStoreSocket);
PlasmaInputStream pis = new PlasmaInputStream(blockId);
for (int i = 0; i < bytesRead.length; i++) {
bytesRead[i] = (byte) pis.read();
}
pis.close();

assertArrayEquals(bytesWrite, bytesRead);
}
Expand All @@ -84,33 +85,29 @@ public void testSingleWriteRead() {
public void testBufferWriteRead() throws IOException {
String blockId = "block_id_" + random.nextInt(10000000);
byte[] bytesWrite = prepareByteBlockToWrite(1);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId, plasmaStoreSocket);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId);
pos.write(bytesWrite);
pos.close();

byte[] bytesRead = new byte[bytesWrite.length];
PlasmaInputStream pis = new PlasmaInputStream(blockId, plasmaStoreSocket);
PlasmaInputStream pis = new PlasmaInputStream(blockId);
pis.read(bytesRead);
pis.close();
assertArrayEquals(bytesWrite, bytesRead);
}

@Test
public void testPartialBlockWriteRead() throws IOException {
String blockId = "block_id_" + random.nextInt(10000000);
byte[] bytesWrite = prepareByteBlockToWrite(2.7);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId, plasmaStoreSocket);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId);
pos.write(bytesWrite);
pos.close();

ByteBuffer bytesRead = ByteBuffer.allocate(bytesWrite.length);
PlasmaInputStream pis = new PlasmaInputStream(blockId, plasmaStoreSocket);
PlasmaInputStream pis = new PlasmaInputStream(blockId);
byte[] buffer = new byte[DEFAULT_BUFFER_SIZE];
int len;
while ((len = pis.read(buffer)) != -1) {
bytesRead.put(buffer, 0, len);
}
pis.close();

assertArrayEquals(bytesWrite, bytesRead.array());
}
Expand All @@ -119,24 +116,23 @@ public void testPartialBlockWriteRead() throws IOException {
public void testMultiBlocksWriteRead() throws IOException {
String blockId = "block_id_" + random.nextInt(10000000);
byte[] bytesWrite = prepareByteBlockToWrite(2);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId, plasmaStoreSocket);
PlasmaOutputStream pos = new PlasmaOutputStream(blockId);
pos.write(bytesWrite);
pos.close();

ByteBuffer bytesRead = ByteBuffer.allocate(bytesWrite.length);
PlasmaInputStream pis = new PlasmaInputStream(blockId, plasmaStoreSocket);
PlasmaInputStream pis = new PlasmaInputStream(blockId);
byte[] buffer = new byte[DEFAULT_BUFFER_SIZE];
while (pis.read(buffer) != -1) {
bytesRead.put(buffer);
}
pis.close();

assertArrayEquals(bytesWrite, bytesRead.array());
}

@After
public void tearDown() {
try {
MyPlasmaClientHolder.close();
stopPlasmaStore();
deletePlasmaSocketFile();
} catch (InterruptedException ex) {
Expand Down Expand Up @@ -210,4 +206,12 @@ private byte[] prepareByteBlockToWrite(double numOfBlock) {
random.nextBytes(bytesToWrite);
return bytesToWrite;
}

/**
 * Installs a mocked {@link SparkEnv} whose configuration sets
 * {@code spark.io.plasma.server.socket} to the plasma store socket used by this test.
 */
private void mockSparkEnv() {
  SparkConf sparkConf = new SparkConf().set("spark.io.plasma.server.socket", plasmaStoreSocket);
  SparkEnv env = mock(SparkEnv.class);
  when(env.conf()).thenReturn(sparkConf);
  SparkEnv.set(env);
}
}

0 comments on commit 808d96d

Please sign in to comment.