Skip to content

Commit

Permalink
[REMOTE-SHUFFLE-52] Update shuffle write to not encode/decode paramet…
Browse files Browse the repository at this point in the history
…ers (#55)

Signed-off-by: jiafu zhang <jiafu.zhang@intel.com>
  • Loading branch information
jiafuzha authored Jan 7, 2022
1 parent 350067d commit a1fddc1
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ protected ByteBuf readFromDaos() throws IOException {
}

private void releaseDescSet() {
descSet.forEach(desc -> desc.release());
descSet.clear();
descSet.forEach(desc -> desc.discard());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ public void close(boolean force) {
e = new IllegalStateException(sb.toString());
}
readyList.forEach(desc -> desc.release());
runningDescSet.forEach(desc -> desc.release());
runningDescSet.forEach(desc -> desc.discard()); // to be released when poll
if (currentDesc != null) {
currentDesc.release();
currentDesc = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.daos.BufferAllocator;
import io.daos.obj.DaosObject;
import io.daos.obj.IODataDescSync;
import io.daos.obj.IODescUpdAsync;
import io.daos.obj.IOSimpleDDAsync;
import io.netty.buffer.ByteBuf;
import org.apache.spark.SparkConf;
Expand Down Expand Up @@ -141,6 +142,17 @@ public interface DaosWriter {
*/
List<SpillInfo> getSpillInfo(int partitionId);

interface ObjectCache<T> {

T get();

T newObject();

void put(T object);

boolean isFull();
}

/**
* Write parameters, including mapId, shuffleId, number of partitions and write config.
*/
Expand Down Expand Up @@ -345,16 +357,17 @@ public List<IODataDescSync> createUpdateDescs(boolean fullBufferOnly) throws IOE
}

/**
* create list of {@link IOSimpleDDAsync} each of them has only one akey entry.
* create list of {@link IODescUpdAsync} each of them has only one akey entry.
* DAOS has a constraint that same akey cannot be referenced twice in one IO.
*
* @param eqHandle
* @return list of {@link IOSimpleDDAsync}
* @return list of {@link IODescUpdAsync}
* @throws IOException
*/
public List<IOSimpleDDAsync> createUpdateDescAsyncs(long eqHandle) throws IOException {
public List<IODescUpdAsync> createUpdateDescAsyncs(long eqHandle, ObjectCache<IODescUpdAsync> cache)
throws IOException {
// make sure each spilled data don't span multiple mapId_<seq>s.
return createUpdateDescAsyncs(eqHandle, true);
return createUpdateDescAsyncs(eqHandle, cache, true);
}

/**
Expand All @@ -367,21 +380,31 @@ public List<IOSimpleDDAsync> createUpdateDescAsyncs(long eqHandle) throws IOExce
* @return list of {@link IOSimpleDDAsync}
* @throws IOException
*/
public List<IOSimpleDDAsync> createUpdateDescAsyncs(long eqHandle, boolean fullBufferOnly) throws IOException {
public List<IODescUpdAsync> createUpdateDescAsyncs(long eqHandle, ObjectCache<IODescUpdAsync> cache,
boolean fullBufferOnly) throws IOException {
int nbrOfBuf = bufList.size();
if ((nbrOfBuf == 0) | (fullBufferOnly & (nbrOfBuf <= 1))) {
return Collections.emptyList();
}
nbrOfBuf -= fullBufferOnly ? 1 : 0;

List<IOSimpleDDAsync> descList = new ArrayList<>(nbrOfBuf);
List<IODescUpdAsync> descList = new ArrayList<>(nbrOfBuf);
String cmapId = currentMapId();
long bufSize = 0;
long offset = needSpill ? 0 : totalSize;
for (int i = 0; i < nbrOfBuf; i++) {
IOSimpleDDAsync desc = object.createAsyncDataDescForUpdate(partitionIdKey, eqHandle);
IODescUpdAsync desc;
ByteBuf buf = bufList.get(i);
desc.addEntryForUpdate(cmapId, offset + bufSize, buf);
if (!cache.isFull()) {
desc = cache.get();
desc.reuse();
desc.setDkey(partitionIdKey);
desc.setAkey(cmapId);
desc.setOffset(offset + bufSize);
desc.setDataBuffer(buf);
} else {
desc = new IODescUpdAsync(partitionIdKey, cmapId, offset + bufSize, buf);
}
bufSize += buf.readableBytes();
descList.add(desc);
}
Expand Down Expand Up @@ -485,6 +508,7 @@ class WriterConfig {
private int totalSubmittedLimit;
private int threads;
private boolean fromOtherThreads;
private int ioDescCaches;
private SparkConf conf;

private static final Logger logger = LoggerFactory.getLogger(WriterConfig.class);
Expand All @@ -504,6 +528,7 @@ class WriterConfig {
totalInMemSize = (long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MAX_BYTES_IN_FLIGHT()) * 1024;
totalSubmittedLimit = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SUBMITTED_LIMIT());
threads = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_THREADS());
ioDescCaches = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_ASYNC_DESC_CACHES());
fromOtherThreads = (boolean)conf
.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD());
if (logger.isDebugEnabled()) {
Expand Down Expand Up @@ -544,6 +569,10 @@ public int getThreads() {
return threads;
}

public int getIoDescCaches() {
return ioDescCaches;
}

public boolean isFromOtherThreads() {
return fromOtherThreads;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
import io.daos.DaosEventQueue;
import io.daos.TimedOutException;
import io.daos.obj.DaosObject;
import io.daos.obj.IOSimpleDDAsync;
import io.daos.obj.IODescUpdAsync;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
Expand All @@ -41,15 +42,18 @@ public class DaosWriterAsync extends DaosWriterBase {

private DaosEventQueue eq;

private Set<IOSimpleDDAsync> descSet = new LinkedHashSet<>();
private Set<IODescUpdAsync> descSet = new LinkedHashSet<>();

private List<DaosEventQueue.Attachment> completedList = new LinkedList<>();

private AsyncDescCache cache;

private static Logger log = LoggerFactory.getLogger(DaosWriterAsync.class);

public DaosWriterAsync(DaosObject object, WriteParam param) throws IOException {
super(object, param);
eq = DaosEventQueue.getInstance(0);
cache = new AsyncDescCache(param.getConfig().getIoDescCaches());
}

@Override
Expand All @@ -58,7 +62,7 @@ public void flush(int partitionId) throws IOException {
if (buffer == null) {
return;
}
List<IOSimpleDDAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl());
List<IODescUpdAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache);
flush(buffer, descList);
}

Expand All @@ -68,24 +72,33 @@ public void flushAll(int partitionId) throws IOException {
if (buffer == null) {
return;
}
List<IOSimpleDDAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), false);
List<IODescUpdAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache, false);
flush(buffer, descList);
}

private void flush(NativeBuffer buffer, List<IOSimpleDDAsync> descList) throws IOException {
private void cacheOrRelease(IODescUpdAsync desc) {
if (desc.isReusable()) {
cache.put(desc);
} else {
desc.release();
}
}

private void flush(NativeBuffer buffer, List<IODescUpdAsync> descList) throws IOException {
if (!descList.isEmpty()) {
assert Thread.currentThread().getId() == eq.getThreadId() : "current thread " + Thread.currentThread().getId() +
"(" + Thread.currentThread().getName() + "), is not expected " + eq.getThreadId() + "(" +
eq.getThreadName() + ")";

for (IOSimpleDDAsync desc : descList) {
for (IODescUpdAsync desc : descList) {
DaosEventQueue.Event event = acquireEvent();
descSet.add(desc);
desc.setEvent(event);
try {
object.updateAsync(desc);
} catch (Exception e) {
desc.release();
cacheOrRelease(desc);
desc.discard();
descSet.remove(desc);
throw e;
}
Expand All @@ -104,7 +117,7 @@ public void flushAll() throws IOException {
if (buffer == null) {
continue;
}
List<IOSimpleDDAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), false);
List<IODescUpdAsync> descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache, false);
flush(buffer, descList);
}
waitCompletion();
Expand All @@ -116,19 +129,16 @@ protected void waitCompletion() throws IOException {
try {
long dur;
long start = System.currentTimeMillis();
while ((left=descSet.size()) > 0 & ((dur = System.currentTimeMillis() - start) < config.getWaitTimeMs())) {
while ((left = descSet.size()) > 0 & ((dur = System.currentTimeMillis() - start) < config.getWaitTimeMs())) {
completedList.clear();
eq.pollCompleted(completedList, IOSimpleDDAsync.class, descSet, left, config.getWaitTimeMs() - dur);
eq.pollCompleted(completedList, IODescUpdAsync.class, descSet, left, config.getWaitTimeMs() - dur);
verifyCompleted();
}
if (!descSet.isEmpty()) {
throw new TimedOutException("timed out after " + (System.currentTimeMillis() - start));
}
} catch (IOException e) {
throw new IllegalStateException("failed to complete all running updates. ", e);
} finally {
descSet.forEach(desc -> desc.release());
descSet.clear();
}
super.flushAll();
}
Expand All @@ -137,7 +147,7 @@ private DaosEventQueue.Event acquireEvent() throws IOException {
completedList.clear();
try {
DaosEventQueue.Event event = eq.acquireEventBlocking(config.getWaitTimeMs(), completedList,
IOSimpleDDAsync.class, descSet);
IODescUpdAsync.class, descSet);
verifyCompleted();
return event;
} catch (IOException e) {
Expand All @@ -147,11 +157,11 @@ private DaosEventQueue.Event acquireEvent() throws IOException {
}

private void verifyCompleted() throws IOException {
IOSimpleDDAsync failed = null;
IODescUpdAsync failed = null;
int failedCnt = 0;
for (DaosEventQueue.Attachment attachment : completedList) {
descSet.remove(attachment);
IOSimpleDDAsync desc = (IOSimpleDDAsync) attachment;
IODescUpdAsync desc = (IODescUpdAsync) attachment;
if (!desc.isSucceeded()) {
failedCnt++;
if (failed == null) {
Expand All @@ -162,12 +172,12 @@ private void verifyCompleted() throws IOException {
if (log.isDebugEnabled()) {
log.debug("written desc: " + desc);
}
desc.release();
cacheOrRelease(desc);
}
if (failedCnt > 0) {
IOException e = new IOException("failed to write " + failedCnt + " IOSimpleDDAsync. Return code is " +
failed.getReturnCode() + ". First failed is " + failed);
failed.release();
cacheOrRelease(failed);
throw e;
}
}
Expand All @@ -183,11 +193,83 @@ public void close() {
completedList.clear();
completedList = null;
}

if (descSet.isEmpty()) { // all descs polled
cache.release();
} else {
descSet.forEach(d -> d.discard()); // to be released when poll
cache.release(descSet);
descSet.clear();
}
super.close();
}

public void setWriterMap(Map<DaosWriter, Integer> writerMap) {
writerMap.put(this, 0);
this.writerMap = writerMap;
}

static class AsyncDescCache implements ObjectCache<IODescUpdAsync> {
private int idx;
private int total;
private IODescUpdAsync[] array;

public AsyncDescCache(int maxNbr) {
this.array = new IODescUpdAsync[maxNbr];
}

@Override
public IODescUpdAsync get() {
if (idx < total) {
return array[idx++];
}
if (idx < array.length) {
array[idx] = newObject();
total++;
return array[idx++];
}
throw new IllegalStateException("cache is full, " + total);
}

@Override
public IODescUpdAsync newObject() {
return new IODescUpdAsync(32);
}

@Override
public void put(IODescUpdAsync desc) {
if (idx <= 0) {
throw new IllegalStateException("more than actual number of IODescUpdAsyncs put back");
}
if (desc.isDiscarded()) {
desc.release();
desc = newObject();
}
array[--idx] = desc;
}

@Override
public boolean isFull() {
return total == array.length;
}

public void release() {
release(Collections.emptySet());
}

private void release(Set<IODescUpdAsync> filterSet) {
for (int i = 0; i < Math.min(total, array.length); i++) {
IODescUpdAsync desc = array[i];
if (desc != null && !filterSet.contains(desc)) {
desc.release();
}
}
array = null;
idx = 0;
}

protected int getIdx() {
return idx;
}
}
}
Loading

0 comments on commit a1fddc1

Please sign in to comment.