Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REMOTE-SHUFFLE-52] Update shuffle write to not encode/decode parameters #55

Merged
merged 1 commit into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
<packaging>pom</packaging>

<properties>
<scala.version>2.12.10</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<scala.version>2.12.10</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<java.version>1.8</java.version>
<maven.compiler.source>${java.version}</maven.compiler.source>
<maven.compiler.target>${java.version}</maven.compiler.target>
Expand Down Expand Up @@ -214,8 +214,9 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
Expand All @@ -231,7 +232,7 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ protected ByteBuf readFromDaos() throws IOException {
}

private void releaseDescSet() {
descSet.forEach(desc -> desc.release());
descSet.clear();
descSet.forEach(desc -> desc.discard());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ public void close(boolean force) {
e = new IllegalStateException(sb.toString());
}
readyList.forEach(desc -> desc.release());
runningDescSet.forEach(desc -> desc.release());
runningDescSet.forEach(desc -> desc.discard()); // to be released when poll
if (currentDesc != null) {
currentDesc.release();
currentDesc = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.daos.BufferAllocator;
import io.daos.obj.DaosObject;
import io.daos.obj.IODataDescSync;
import io.daos.obj.IODescUpdAsync;
import io.daos.obj.IOSimpleDDAsync;
import io.netty.buffer.ByteBuf;
import org.apache.spark.SparkConf;
Expand Down Expand Up @@ -141,6 +142,17 @@ public interface DaosWriter {
*/
List<SpillInfo> getSpillInfo(int partitionId);

interface ObjectCache<T> {

T get();

T newObject();

void put(T object);

boolean isFull();
}

/**
* Write parameters, including mapId, shuffleId, number of partitions and write config.
*/
Expand Down Expand Up @@ -345,16 +357,17 @@ public List<IODataDescSync> createUpdateDescs(boolean fullBufferOnly) throws IOE
}

/**
* create list of {@link IOSimpleDDAsync} each of them has only one akey entry.
* create list of {@link IODescUpdAsync} each of them has only one akey entry.
* DAOS has a constraint that same akey cannot be referenced twice in one IO.
*
* @param eqHandle
* @return list of {@link IOSimpleDDAsync}
* @return list of {@link IODescUpdAsync}
* @throws IOException
*/
public List<IOSimpleDDAsync> createUpdateDescAsyncs(long eqHandle) throws IOException {
public List<IODescUpdAsync> createUpdateDescAsyncs(long eqHandle, ObjectCache<IODescUpdAsync> cache)
throws IOException {
// make sure each spilled data don't span multiple mapId_<seq>s.
return createUpdateDescAsyncs(eqHandle, true);
return createUpdateDescAsyncs(eqHandle, cache, true);
}

/**
Expand All @@ -367,21 +380,31 @@ public List<IOSimpleDDAsync> createUpdateDescAsyncs(long eqHandle) throws IOExce
* @return list of {@link IOSimpleDDAsync}
* @throws IOException
*/
public List<IOSimpleDDAsync> createUpdateDescAsyncs(long eqHandle, boolean fullBufferOnly) throws IOException {
public List<IODescUpdAsync> createUpdateDescAsyncs(long eqHandle, ObjectCache<IODescUpdAsync> cache,
boolean fullBufferOnly) throws IOException {
int nbrOfBuf = bufList.size();
if ((nbrOfBuf == 0) | (fullBufferOnly & (nbrOfBuf <= 1))) {
return Collections.emptyList();
}
nbrOfBuf -= fullBufferOnly ? 1 : 0;

List<IOSimpleDDAsync> descList = new ArrayList<>(nbrOfBuf);
List<IODescUpdAsync> descList = new ArrayList<>(nbrOfBuf);
String cmapId = currentMapId();
long bufSize = 0;
long offset = needSpill ? 0 : totalSize;
for (int i = 0; i < nbrOfBuf; i++) {
IOSimpleDDAsync desc = object.createAsyncDataDescForUpdate(partitionIdKey, eqHandle);
IODescUpdAsync desc;
ByteBuf buf = bufList.get(i);
desc.addEntryForUpdate(cmapId, offset + bufSize, buf);
if (!cache.isFull()) {
desc = cache.get();
desc.reuse();
desc.setDkey(partitionIdKey);
desc.setAkey(cmapId);
desc.setOffset(offset + bufSize);
desc.setDataBuffer(buf);
} else {
desc = new IODescUpdAsync(partitionIdKey, cmapId, offset + bufSize, buf);
}
bufSize += buf.readableBytes();
descList.add(desc);
}
Expand Down Expand Up @@ -485,6 +508,7 @@ class WriterConfig {
private int totalSubmittedLimit;
private int threads;
private boolean fromOtherThreads;
private int ioDescCaches;
private SparkConf conf;

private static final Logger logger = LoggerFactory.getLogger(WriterConfig.class);
Expand All @@ -504,6 +528,7 @@ class WriterConfig {
totalInMemSize = (long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MAX_BYTES_IN_FLIGHT()) * 1024;
totalSubmittedLimit = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SUBMITTED_LIMIT());
threads = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_THREADS());
ioDescCaches = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_ASYNC_DESC_CACHES());
fromOtherThreads = (boolean)conf
.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD());
if (logger.isDebugEnabled()) {
Expand Down Expand Up @@ -544,6 +569,10 @@ public int getThreads() {
return threads;
}

  /**
   * Number of async IO update descriptors to cache per writer, read from the
   * SHUFFLE_DAOS_WRITE_ASYNC_DESC_CACHES Spark configuration entry.
   *
   * @return configured descriptor cache size
   */
  public int getIoDescCaches() {
    return ioDescCaches;
  }

  /**
   * Whether shuffle data should be written from other threads, read from the
   * SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD Spark configuration entry.
   *
   * @return true if writing is delegated to other threads
   */
  public boolean isFromOtherThreads() {
    return fromOtherThreads;
  }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
import io.daos.DaosEventQueue;
import io.daos.TimedOutException;
import io.daos.obj.DaosObject;
import io.daos.obj.IOSimpleDDAsync;
import io.daos.obj.IODescUpdAsync;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
Expand All @@ -41,15 +42,18 @@ public class DaosWriterAsync extends DaosWriterBase {

private DaosEventQueue eq;

private Set<IOSimpleDDAsync> descSet = new LinkedHashSet<>();
private Set<IODescUpdAsync> descSet = new LinkedHashSet<>();

private List<DaosEventQueue.Attachment> completedList = new LinkedList<>();

private AsyncDescCache cache;

private static Logger log = LoggerFactory.getLogger(DaosWriterAsync.class);

  /**
   * Create an async writer over the given DAOS object.
   *
   * @param object opened DAOS object shuffle data is written to
   * @param param write parameters (mapId, shuffleId, number of partitions and write config)
   * @throws IOException if the event queue instance cannot be obtained
   */
  public DaosWriterAsync(DaosObject object, WriteParam param) throws IOException {
    super(object, param);
    // event queue instance 0; flush() asserts it is used from this same thread
    eq = DaosEventQueue.getInstance(0);
    // descriptor cache sized by the SHUFFLE_DAOS_WRITE_ASYNC_DESC_CACHES config
    cache = new AsyncDescCache(param.getConfig().getIoDescCaches());
  }

@Override
Expand All @@ -58,7 +62,7 @@ public void flush(int partitionId) throws IOException {
if (buffer == null) {
return;
}
List<IOSimpleDDAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl());
List<IODescUpdAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache);
flush(buffer, descList);
}

Expand All @@ -68,24 +72,33 @@ public void flushAll(int partitionId) throws IOException {
if (buffer == null) {
return;
}
List<IOSimpleDDAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), false);
List<IODescUpdAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache, false);
flush(buffer, descList);
}

private void flush(NativeBuffer buffer, List<IOSimpleDDAsync> descList) throws IOException {
private void cacheOrRelease(IODescUpdAsync desc) {
if (desc.isReusable()) {
cache.put(desc);
} else {
desc.release();
}
}

private void flush(NativeBuffer buffer, List<IODescUpdAsync> descList) throws IOException {
if (!descList.isEmpty()) {
assert Thread.currentThread().getId() == eq.getThreadId() : "current thread " + Thread.currentThread().getId() +
"(" + Thread.currentThread().getName() + "), is not expected " + eq.getThreadId() + "(" +
eq.getThreadName() + ")";

for (IOSimpleDDAsync desc : descList) {
for (IODescUpdAsync desc : descList) {
DaosEventQueue.Event event = acquireEvent();
descSet.add(desc);
desc.setEvent(event);
try {
object.updateAsync(desc);
} catch (Exception e) {
desc.release();
cacheOrRelease(desc);
desc.discard();
descSet.remove(desc);
throw e;
}
Expand All @@ -104,7 +117,7 @@ public void flushAll() throws IOException {
if (buffer == null) {
continue;
}
List<IOSimpleDDAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), false);
List<IODescUpdAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache, false);
flush(buffer, descList);
}
waitCompletion();
Expand All @@ -116,19 +129,16 @@ protected void waitCompletion() throws IOException {
try {
long dur;
long start = System.currentTimeMillis();
while ((left=descSet.size()) > 0 & ((dur = System.currentTimeMillis() - start) < config.getWaitTimeMs())) {
while ((left = descSet.size()) > 0 & ((dur = System.currentTimeMillis() - start) < config.getWaitTimeMs())) {
completedList.clear();
eq.pollCompleted(completedList, IOSimpleDDAsync.class, descSet, left, config.getWaitTimeMs() - dur);
eq.pollCompleted(completedList, IODescUpdAsync.class, descSet, left, config.getWaitTimeMs() - dur);
verifyCompleted();
}
if (!descSet.isEmpty()) {
throw new TimedOutException("timed out after " + (System.currentTimeMillis() - start));
}
} catch (IOException e) {
throw new IllegalStateException("failed to complete all running updates. ", e);
} finally {
descSet.forEach(desc -> desc.release());
descSet.clear();
}
super.flushAll();
}
Expand All @@ -137,7 +147,7 @@ private DaosEventQueue.Event acquireEvent() throws IOException {
completedList.clear();
try {
DaosEventQueue.Event event = eq.acquireEventBlocking(config.getWaitTimeMs(), completedList,
IOSimpleDDAsync.class, descSet);
IODescUpdAsync.class, descSet);
verifyCompleted();
return event;
} catch (IOException e) {
Expand All @@ -147,11 +157,11 @@ private DaosEventQueue.Event acquireEvent() throws IOException {
}

private void verifyCompleted() throws IOException {
IOSimpleDDAsync failed = null;
IODescUpdAsync failed = null;
int failedCnt = 0;
for (DaosEventQueue.Attachment attachment : completedList) {
descSet.remove(attachment);
IOSimpleDDAsync desc = (IOSimpleDDAsync) attachment;
IODescUpdAsync desc = (IODescUpdAsync) attachment;
if (!desc.isSucceeded()) {
failedCnt++;
if (failed == null) {
Expand All @@ -162,12 +172,12 @@ private void verifyCompleted() throws IOException {
if (log.isDebugEnabled()) {
log.debug("written desc: " + desc);
}
desc.release();
cacheOrRelease(desc);
}
if (failedCnt > 0) {
IOException e = new IOException("failed to write " + failedCnt + " IOSimpleDDAsync. Return code is " +
failed.getReturnCode() + ". First failed is " + failed);
failed.release();
cacheOrRelease(failed);
throw e;
}
}
Expand All @@ -183,11 +193,83 @@ public void close() {
completedList.clear();
completedList = null;
}

if (descSet.isEmpty()) { // all descs polled
cache.release();
} else {
descSet.forEach(d -> d.discard()); // to be released when poll
cache.release(descSet);
descSet.clear();
}
super.close();
}

public void setWriterMap(Map<DaosWriter, Integer> writerMap) {
writerMap.put(this, 0);
this.writerMap = writerMap;
}

static class AsyncDescCache implements ObjectCache<IODescUpdAsync> {
private int idx;
private int total;
private IODescUpdAsync[] array;

public AsyncDescCache(int maxNbr) {
this.array = new IODescUpdAsync[maxNbr];
}

@Override
public IODescUpdAsync get() {
if (idx < total) {
return array[idx++];
}
if (idx < array.length) {
array[idx] = newObject();
total++;
return array[idx++];
}
throw new IllegalStateException("cache is full, " + total);
}

@Override
public IODescUpdAsync newObject() {
return new IODescUpdAsync(32);
}

@Override
public void put(IODescUpdAsync desc) {
if (idx <= 0) {
throw new IllegalStateException("more than actual number of IODescUpdAsyncs put back");
}
if (desc.isDiscarded()) {
desc.release();
desc = newObject();
}
array[--idx] = desc;
}

@Override
public boolean isFull() {
return total == array.length;
}

public void release() {
release(Collections.emptySet());
}

private void release(Set<IODescUpdAsync> filterSet) {
for (int i = 0; i < Math.min(total, array.length); i++) {
IODescUpdAsync desc = array[i];
if (desc != null && !filterSet.contains(desc)) {
desc.release();
}
}
array = null;
idx = 0;
}

protected int getIdx() {
return idx;
}
}
}
Loading