Commit 9d25542

[POAE7-497] PMem extension implementation for UnsafeExternalSorter (apache#39)

* [POAE7-497] add memory manager for PMem

* [POAE7-497] memory spill to PMem for UnsafeExternalSorter
yma11 authored Oct 30, 2020
1 parent 3fdfce3 commit 9d25542
Showing 19 changed files with 1,224 additions and 28 deletions.
5 changes: 5 additions & 0 deletions common/unsafe/pom.xml
@@ -35,6 +35,11 @@
</properties>

<dependencies>
<dependency>
<groupId>com.intel</groupId>
<artifactId>oap</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
@@ -0,0 +1,21 @@

package org.apache.spark.unsafe.memory;
import com.intel.oap.common.unsafe.PersistentMemoryPlatform;

public class ExtendedMemoryAllocator implements MemoryAllocator {

@Override
public MemoryBlock allocate(long size) throws OutOfMemoryError {
long address = PersistentMemoryPlatform.allocateVolatileMemory(size);
MemoryBlock memoryBlock = new MemoryBlock(null, address, size);

return memoryBlock;
}

@Override
public void free(MemoryBlock memoryBlock) {
assert (memoryBlock.getBaseObject() == null) :
"baseObject not null; are you trying to use the PMem allocator to free on-heap memory?";
PersistentMemoryPlatform.freeMemory(memoryBlock.getBaseOffset());
}
}
@@ -41,4 +41,6 @@ public interface MemoryAllocator {
MemoryAllocator UNSAFE = new UnsafeMemoryAllocator();

MemoryAllocator HEAP = new HeapMemoryAllocator();

MemoryAllocator EXTENDED = new ExtendedMemoryAllocator();
}
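The new EXTENDED constant exposes the PMem-backed allocator alongside the existing UNSAFE and HEAP allocators. A minimal sketch of exercising it directly (it assumes the OAP PersistentMemoryPlatform has already been initialized elsewhere; the surrounding setup is hypothetical, not part of this commit):

// Hedged sketch: allocate, use, and free a PMem block through the new allocator constant.
MemoryAllocator pmem = MemoryAllocator.EXTENDED;
MemoryBlock block = pmem.allocate(4096);      // volatile allocation backed by PMem
long base = block.getBaseOffset();            // raw address; getBaseObject() is null for PMem blocks
Platform.putLong(null, base, 42L);            // access goes through the unsafe Platform helpers
long value = Platform.getLong(null, base);
pmem.free(block);                             // hands the space back to PersistentMemoryPlatform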
@@ -47,6 +47,11 @@ public class MemoryBlock extends MemoryLocation {

private final long length;

/**
* Indicates where this page resides: DRAM (0) or extended memory such as PMem (1).
*/
public int location;

/**
* Optional page number; used when this MemoryBlock represents a page allocated by a
* TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
5 changes: 5 additions & 0 deletions core/pom.xml
@@ -35,6 +35,11 @@
</properties>

<dependencies>
<dependency>
<groupId>com.intel</groupId>
<artifactId>oap</artifactId>
<version>0.9.0</version>
</dependency>
<dependency>
<groupId>com.thoughtworks.paranamer</groupId>
<artifactId>paranamer</artifactId>
@@ -129,6 +129,10 @@ protected void freePage(MemoryBlock page) {
taskMemoryManager.freePage(page, this);
}

protected void freePMemPage(MemoryBlock page) {
taskMemoryManager.freePMemPage(page, this);
}

/**
* Allocates memory of `size`.
*/
@@ -238,6 +238,19 @@ public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode());
}

public long acquireExtendedMemory(long required) {
assert(required >= 0);
logger.info("Task {} acquire {} bytes PMem memory.", taskAttemptId, Utils.bytesToString(required));
synchronized (this) {
long got = memoryManager.acquireExtendedMemory(required, taskAttemptId);
return got;
}
}

public void releaseExtendedMemory(long size) {
logger.debug("Task {} release {} PMem space.", taskAttemptId, Utils.bytesToString(size));
memoryManager.releaseExtendedMemory(size, taskAttemptId);
}
/**
* Dump the memory usage of all consumers.
*/
@@ -316,6 +329,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
return allocatePage(size, consumer);
}
page.pageNumber = pageNumber;
page.location = 0;
pageTable[pageNumber] = page;
if (logger.isTraceEnabled()) {
logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
@@ -350,6 +364,66 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) {
releaseExecutionMemory(pageSize, consumer);
}

/**
* Allocates a PMem page. A PMem page is allocated only after the memory consumer has already
* reserved enough PMem space via acquireExtendedMemory, so no extra capacity check is done here.
* @param size the requested page size in bytes
* @return the allocated PMem page
*/
public MemoryBlock allocatePMemPage(long size) {
if (size > MAXIMUM_PAGE_SIZE_BYTES) {
throw new TooLargePageException(size);
}
final int pageNumber;
synchronized (this) {
pageNumber = allocatedPages.nextClearBit(0);
if (pageNumber >= PAGE_TABLE_SIZE) {
releaseExtendedMemory(size);
throw new IllegalStateException(
"Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
}
allocatedPages.set(pageNumber);
}
MemoryBlock page = null;
try {
page = memoryManager.extendedMemoryAllocator().allocate(size);
} catch (OutOfMemoryError e) {
logger.error("Failed to allocate a PMem page ({} bytes).", size);
// Release the reserved page number before propagating, otherwise the slot would leak
// and the null page below would cause a NullPointerException.
synchronized (this) {
allocatedPages.clear(pageNumber);
}
throw e;
}
page.pageNumber = pageNumber;
page.location = 1;
pageTable[pageNumber] = page;
if (logger.isTraceEnabled()) {
logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
}
return page;
}

public void freePMemPage(MemoryBlock page, MemoryConsumer consumer) {
assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
"Called freePMemPage() on memory that wasn't allocated with allocatePMemPage()";
assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
"Called freePMemPage() on a memory block that has already been freed";
assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
"Called freePMemPage() on a memory block that has already been freed";
assert(allocatedPages.get(page.pageNumber));
pageTable[page.pageNumber] = null;
synchronized (this) {
allocatedPages.clear(page.pageNumber);
}
if (logger.isTraceEnabled()) {
logger.trace("Freed PMem page number {} ({} bytes)", page.pageNumber, page.size());
}
long pageSize = page.size();
// Clear the page number before passing the block to the MemoryAllocator's free().
// Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
// page has been inappropriately directly freed without calling TMM.freePage().
page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
memoryManager.extendedMemoryAllocator().free(page);
releaseExtendedMemory(pageSize);
}

/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
@@ -402,6 +476,19 @@ public Object getPage(long pagePlusOffsetAddress) {
}
}

public MemoryBlock getOriginalPage(long pagePlusOffsetAddress) {
if (tungstenMemoryMode == MemoryMode.ON_HEAP) {
final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
final MemoryBlock page = pageTable[pageNumber];
assert (page != null);
assert (page.getBaseObject() != null);
return page;
} else {
return null;
}
}

/**
* Get the offset associated with an address encoded by
* {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
@@ -439,7 +526,11 @@ public long cleanUpAllAllocatedMemory() {
if (page != null) {
logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
memoryManager.tungstenMemoryAllocator().free(page);
if (page.location == 0) {
memoryManager.tungstenMemoryAllocator().free(page);
} else {
memoryManager.extendedMemoryAllocator().free(page);
}
}
}
Arrays.fill(pageTable, null);
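Taken together, acquireExtendedMemory, allocatePMemPage, and freePMemPage give a MemoryConsumer a reserve-allocate-free lifecycle for PMem pages. A minimal sketch from inside a hypothetical MemoryConsumer subclass (spillToPMem and its fallback policy are illustrative only, not part of this commit):

// Hedged sketch of the PMem page lifecycle inside a MemoryConsumer subclass.
protected boolean spillToPMem(long required) {
long got = taskMemoryManager.acquireExtendedMemory(required);   // reserve PMem space first
if (got < required) {
taskMemoryManager.releaseExtendedMemory(got);                   // not enough PMem: give the reservation back
return false;                                                   // caller falls back to a disk spill
}
MemoryBlock pMemPage = taskMemoryManager.allocatePMemPage(got); // page.location == 1
// ... copy spilled records into the page with Platform.copyMemory(...) ...
freePMemPage(pMemPage);                                         // frees the page and releases the reservation
return true;
}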
@@ -0,0 +1,87 @@
package org.apache.spark.util.collection.unsafe.sort;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;

import java.io.Closeable;
import java.util.LinkedList;

public final class PMemReader extends UnsafeSorterIterator implements Closeable {
private int recordLength;
private long keyPrefix;
private int numRecordsRemaining;
private int numRecords;
private LinkedList<MemoryBlock> pMemPages;
private MemoryBlock pMemPage = null;
private int readingPageIndex = 0;
private int readedRecordsInCurrentPage = 0;
private int numRecordsInpage = 0;
private long offset = 0;
private byte[] arr = new byte[1024 * 1024];
private Object baseObject = arr;
public PMemReader(LinkedList<MemoryBlock> pMemPages, int numRecords) {
this.pMemPages = pMemPages;
this.numRecordsRemaining = this.numRecords = numRecords;
}
@Override
public void loadNext() {
assert (readingPageIndex <= pMemPages.size())
: "Illegal state: all PMem pages have been read but hasNext() is still true.";
if (pMemPage == null || readedRecordsInCurrentPage == numRecordsInpage) {
// read records from each page
pMemPage = pMemPages.get(readingPageIndex++);
readedRecordsInCurrentPage = 0;
numRecordsInpage = Platform.getInt(null, pMemPage.getBaseOffset());
offset = pMemPage.getBaseOffset() + 4;
}
// Each record is laid out as: key prefix (8 bytes), record length (4 bytes), record bytes
keyPrefix = Platform.getLong(null, offset);
offset += 8;
recordLength = Platform.getInt(null, offset);
offset += 4;
if (recordLength > arr.length) {
arr = new byte[recordLength];
baseObject = arr;
}
Platform.copyMemory(null, offset, baseObject, Platform.BYTE_ARRAY_OFFSET, recordLength);
offset += recordLength;
readedRecordsInCurrentPage++;
numRecordsRemaining--;
}
@Override
public int getNumRecords() {
return numRecords;
}

@Override
public boolean hasNext() {
return (numRecordsRemaining > 0);
}

@Override
public Object getBaseObject() {
return baseObject;
}

@Override
public long getBaseOffset() {
return Platform.BYTE_ARRAY_OFFSET;
}

@Override
public int getRecordLength() {
return recordLength;
}

@Override
public long getKeyPrefix() {
return keyPrefix;
}

@Override
public void close() {
// do nothing here
}
}
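PMemReader expects each PMem page to begin with a 4-byte record count, followed by records laid out as an 8-byte key prefix, a 4-byte record length, and the record bytes. The matching writer side is not shown in this view; a hedged sketch of producing that layout (Record and its fields are illustrative placeholders):

// Hedged sketch of writing one PMem page in the layout PMemReader consumes.
long offset = pMemPage.getBaseOffset();
Platform.putInt(null, offset, numRecordsInPage);   // page header: number of records in this page
offset += 4;
for (Record r : records) {                         // Record is an illustrative holder, not a real class
Platform.putLong(null, offset, r.keyPrefix);       // 8-byte key prefix
offset += 8;
Platform.putInt(null, offset, r.length);           // 4-byte record length
offset += 4;
Platform.copyMemory(r.baseObject, r.baseOffset, null, offset, r.length); // record bytes
offset += r.length;
}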
@@ -0,0 +1,73 @@
package org.apache.spark.util.collection.unsafe.sort;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;
import org.apache.spark.unsafe.array.LongArray;

import java.io.Closeable;

public final class PMemReaderForUnsafeExternalSorter extends UnsafeSorterIterator implements Closeable {
private int recordLength;
private long keyPrefix;
private int numRecordsRemaining;
private int numRecords;
private LongArray sortedArray;
private int position = 0;
private byte[] arr = new byte[1024 * 1024];
private Object baseObject = arr;
public PMemReaderForUnsafeExternalSorter(LongArray sortedArray, int numRecords) {
this.sortedArray = sortedArray;
this.numRecordsRemaining = this.numRecords = numRecords;
}
@Override
public void loadNext() {
assert (position < numRecords * 2)
: "Illegal state: all records have been read but hasNext() is still true.";
final long address = sortedArray.get(position);
keyPrefix = sortedArray.get(position + 1);
int uaoSize = UnsafeAlignedOffset.getUaoSize();
recordLength = UnsafeAlignedOffset.getSize(null, address);
if (recordLength > arr.length) {
arr = new byte[recordLength];
baseObject = arr;
}
Platform.copyMemory(null, address + uaoSize, baseObject, Platform.BYTE_ARRAY_OFFSET, recordLength);
numRecordsRemaining--;
position += 2;
}
@Override
public int getNumRecords() {
return numRecords;
}

@Override
public boolean hasNext() {
return (numRecordsRemaining > 0);
}

@Override
public Object getBaseObject() {
return baseObject;
}

@Override
public long getBaseOffset() {
return Platform.BYTE_ARRAY_OFFSET;
}

@Override
public int getRecordLength() {
return recordLength;
}

@Override
public long getKeyPrefix() {
return keyPrefix;
}

@Override
public void close() {
// do nothing here
}
}
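PMemReaderForUnsafeExternalSorter walks a sorted LongArray of (record address, key prefix) pairs, where each record starts with a UAO-size length word followed by the record bytes, and exposes them through the standard UnsafeSorterIterator protocol. A hedged sketch of draining it (sortedArray, numRecords, and writeRecord are assumed to exist in the caller and are not part of this commit):

// Hedged sketch: consume the iterator like any other UnsafeSorterIterator.
UnsafeSorterIterator iter = new PMemReaderForUnsafeExternalSorter(sortedArray, numRecords);
while (iter.hasNext()) {
iter.loadNext();                       // copies the next record into an on-heap byte[] buffer
writeRecord(iter.getBaseObject(),      // the reused byte[] buffer
iter.getBaseOffset(),                  // Platform.BYTE_ARRAY_OFFSET
iter.getRecordLength(),
iter.getKeyPrefix());
}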